Fix allgather ops inside cuda graphs #3709

Open
wants to merge 5 commits into main

Conversation

nvcastet

Fixes #3424

TL;DR:
sglang uses pynccl or its custom allreduce instead of the PyTorch ProcessGroup when graph capturing.
The allgather in the sglang code base calls the PyTorch allgather directly instead of going through the sglang abstraction that decides which implementation to pick. Allgather is used in DP attention and also to gather logits across the TP dimension.
The fix is to perform the allgather via the abstraction, so that the same NCCL communicator is not used both inside and outside graph capture.
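
Below is a minimal sketch of the idea (not the actual diff), assuming sglang exposes its tensor-parallel group abstraction via something like get_tp_group(); the import path and names are illustrative.

```python
import torch

# Assumed accessor for sglang's TP group abstraction; the exact import
# path/name is illustrative and may differ from the real code base.
from sglang.srt.distributed import get_tp_group


def gather_logits(local_logits: torch.Tensor) -> torch.Tensor:
    # Before (problematic): calling torch.distributed.all_gather directly
    # always goes through the PyTorch ProcessGroup's NCCL communicator,
    # even while a CUDA graph is being captured.
    #
    # After (this PR's idea): route the all-gather through the sglang
    # abstraction, which picks pynccl (or custom kernels) during graph
    # capture and the regular ProcessGroup otherwise, so the same NCCL
    # communicator is never used both inside and outside a capture.
    return get_tp_group().all_gather(local_logits, dim=-1)
```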

yizhang2077
yizhang2077 previously approved these changes Feb 20, 2025
Collaborator

LGTM cc @zhyncs @ispobock

Collaborator

@yizhang2077 left a comment

It seems cuda graph capture failed in one test.

@Superskyyy

Seeing the same issue, please merge soon :)

@zhyncs
Member

zhyncs commented Feb 20, 2025

It seems cuda graph capture failed.

Hi @nvcastet, could you help take a look? Thanks!

@nvcastet
Author

Looking.
@zhyncs Do you have the repro command to reproduce the CI failure?

@yizhang2077
Collaborator

Looking. @zhyncs Do you have the repro command to reproduce the CI failure?

Hi @nvcastet, here is the command: python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1

@yizhang2077
Collaborator

yizhang2077 commented Feb 20, 2025

Hi @nvcastet, I see that CUDA graph capture fails when using the pynccl all-gather in the logits layer, but it succeeds in DP attention, since test_dp_attention.py has passed. If the CI failure is hard to address, one optional solution is to use the pynccl all-gather only in the DP attention layer and keep the other all-gather call sites the same as before. Do you think this approach can work?

@nvcastet
Author

@yizhang2077 Still debugging. Can I have another hour-ish?
Yes, it can work, but I have seen people complaining about this issue for non-DP-attention configs too.
For the non-DP config paths, we are just getting lucky that the allgather in logits works with CUDA graphs, since technically we should use a separate NCCL TP communicator when doing both eager and graph execution in the same run.

@yizhang2077
Collaborator

@yizhang2077 Still debugging. Can I have another hour-ish? Yes, it can work, but I have seen people complaining about this issue for non-DP-attention configs too. For the non-DP config paths, we are just getting lucky that the allgather in logits works with CUDA graphs, since technically we should use a separate NCCL TP communicator when doing both eager and graph execution in the same run.

OK, I think that is reasonable.

@nvcastet
Author

@yizhang2077 Is it reproducible on your side?
I used one of my container images based on commit 9f635ea on an 8xH200 box and could not reproduce the failure with or without my PR applied to it.

@yizhang2077
Collaborator

yizhang2077 commented Feb 20, 2025

@yizhang2077 Is it reproducible on your side? I used one of my container images based on commit 9f635ea on an 8xH200 box and could not reproduce the failure with or without my PR applied to it.

It is weird, I will try it on my env.

@zhyncs
Member

zhyncs commented Feb 20, 2025

@nvcastet @yizhang2077
https://github.com/sgl-project/sglang/actions/runs/13432619675/job/37556327516?pr=3709
The CUDA graph capture failure occurs on CI; note that CI uses CUDA 12.1.

@yizhang2077
Collaborator

yizhang2077 commented Feb 20, 2025

Hi @nvcastet, I can reproduce this error in my env: torch 2.5.1+cu124, 8xH200, docker image lmsysorg/sglang:dev.

@nvcastet
Author

nvcastet commented Feb 20, 2025

@yizhang2077 Thanks, let me try your container image.
This might be related:

CUDA RNG operations are permitted, and when using multiple torch.Generator instances within a graph, they must be registered using CUDAGraph.register_generator_state before graph capture. Avoid using Generator.get_state and Generator.set_state during capture; instead, utilize Generator.graphsafe_set_state and Generator.graphsafe_get_state for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.

From https://pytorch.org/docs/stable/notes/cuda.html#constraints

@nvcastet
Author

The failure did not happen for me on PyTorch 2.7, but it does on 2.5.
While debugging torch.compile:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin None._SimpleCData.__new__. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
[rank0]:V0220 21:52:05.352000 21617 torch/_dynamo/guards.py:2813] [29/1] [__recompiles] Recompiling function torch_dynamo_resume_in_all_gather_at_163 in /sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py:163
[rank0]:V0220 21:52:05.352000 21617 torch/_dynamo/guards.py:2813] [29/1] [__recompiles]     triggered by the following guard failure(s):
[rank0]:V0220 21:52:05.352000 21617 torch/_dynamo/guards.py:2813] [29/1] [__recompiles]     - 29/0: L['___stack2'] == 139835979366400

which points to pynccl not being traceable by torch.compile (it breaks the torch.compile graph) and to it being recompiled later, during CUDA graph capture.

I added a commit that removes pynccl and replaces it with a ProcessGroup clone, which should behave the same way.
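
For illustration only (this is not the commit itself): a "ProcessGroup clone" in the sense used above could be created as a second NCCL group over the same ranks, so collectives issued during graph capture use a separate NCCL communicator from the eager-mode ones.

```python
import torch.distributed as dist

# Sketch: build a second NCCL process group spanning the same ranks as
# the default group. Collectives launched on this group use their own
# NCCL communicator, so graph-captured collectives do not share a
# communicator with eager-mode collectives on the original group.
world_size = dist.get_world_size()
capture_pg = dist.new_group(ranks=list(range(world_size)), backend="nccl")

# Usage inside the capture path, e.g.:
#   dist.all_gather_into_tensor(output, input, group=capture_pg)
```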

@yizhang2077
Collaborator

@nvcastet It is weird, since the pynccl allreduce is also on the critical path and is graphable.

@ispobock
Collaborator

ispobock commented Feb 21, 2025

Hi @nvcastet, pynccl cannot be removed since it is needed for the cu118 environment. PyTorch cu118 installs nvidia-nccl-cu11, which by default ships an NCCL built against CUDA 11.0, but CUDA graphs need an NCCL built against CUDA 11.3 or higher. So we can specify the NCCL .so path via SGLANG_NCCL_SO_PATH (pointing to a build with CUDA 11.3+ support) and use the pynccl wrapper to make it compatible with CUDA graphs.
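
As a rough illustration of that workaround (the library path below is hypothetical):

```python
import os

# Hypothetical path: point the pynccl wrapper at an NCCL build compiled
# against CUDA 11.3+ instead of the NCCL bundled with the cu118 wheels.
# This must be set before the NCCL communicators are created.
os.environ["SGLANG_NCCL_SO_PATH"] = "/opt/nccl/lib/libnccl.so.2"
```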

@nvcastet
Author

nvcastet commented Feb 21, 2025

UserWarning: Graph break due to unsupported builtin None._SimpleCData.__new__. ...
@nvcastet It is weird, since the pynccl allreduce is also on the critical path and is graphable.

The message here (displayed using TORCH_LOGS="recompiles") is not about the CUDA graph, but about the torch.compile graph.

pynccl cannot be removed since it is needed for the cu118 environment. PyTorch cu118 installs nvidia-nccl-cu11, which by default ships an NCCL built against CUDA 11.0, but CUDA graphs need an NCCL built against CUDA 11.3 or higher. So we can specify the NCCL .so path via SGLANG_NCCL_SO_PATH (pointing to a build with CUDA 11.3+ support) and use the pynccl wrapper to make it compatible with CUDA graphs.

@ispobock Thank you for pointing that out, I was not aware of this issue.

@nvcastet
Author

@ispobock when downloading nvidia-nccl-cu11, I see cu116:

# pip download nvidia-nccl-cu11
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com/
Collecting nvidia-nccl-cu11
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-nccl-cu11/nvidia-nccl-cu11-2022.5.19.tar.gz (16 kB)
  Preparing metadata (setup.py) ... done
Collecting nvidia-nccl-cu116 (from nvidia-nccl-cu11)
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-nccl-cu116/nvidia_nccl_cu116-2.12.12-py3-none-manylinux1_x86_64.whl (164.8 MB)

So we should be good, right?

@nvcastet nvcastet changed the title Fix allgather ops inside cuda graphs using the pynccl communicator Fix allgather ops inside cuda graphs Feb 21, 2025
@yizhang2077
Collaborator

yizhang2077 commented Feb 22, 2025

Hi @nvcastet, I think the main problem currently is that torch.compile conflicts with pynccl, right? You could try wrapping the allgather ops like this:

direct_register_custom_op(
    op_name="outplace_all_reduce",
    op_func=outplace_all_reduce,
    mutates_args=[],
    fake_impl=outplace_all_reduce_fake,
)
which keeps pynccl out of the torch.compile graph (see custom operators). I tried writing it this way in my env, and it passes python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1. Could you give it a try?
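
A minimal sketch of what that wrapping could look like for the all-gather, assuming the same direct_register_custom_op helper quoted above and a get_tp_group() accessor to the group abstraction; import paths, names, and the op namespace are illustrative, not the actual patch.

```python
import torch

# Assumed helpers; exact import paths may differ in the real code base.
from sglang.srt.utils import direct_register_custom_op
from sglang.srt.distributed import get_tp_group


def outplace_all_gather(input_: torch.Tensor, dim: int) -> torch.Tensor:
    # Real implementation: run the group's all-gather (pynccl during CUDA
    # graph capture, the ProcessGroup otherwise).
    return get_tp_group().all_gather(input_, dim=dim)


def outplace_all_gather_fake(input_: torch.Tensor, dim: int) -> torch.Tensor:
    # Fake (meta) implementation: only compute the output shape so that
    # torch.compile can trace the op without touching NCCL.
    shape = list(input_.shape)
    shape[dim] = shape[dim] * get_tp_group().world_size
    return input_.new_empty(shape)


direct_register_custom_op(
    op_name="outplace_all_gather",
    op_func=outplace_all_gather,
    mutates_args=[],
    fake_impl=outplace_all_gather_fake,
)

# Call sites then dispatch through the registered custom op, which Dynamo
# treats as an opaque operator instead of tracing into pynccl, e.g.:
#   out = torch.ops.sglang.outplace_all_gather(x, 0)
# (the namespace depends on how direct_register_custom_op registers ops).
```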

@desertchen

@nvcastet thx, it works!

@robscc

robscc commented Feb 23, 2025

I have patched and tested this PR, but the "Watchdog caught collective operation timeout" problem still occurs.
Here is the log output:

[rank4]:[E223 11:14:34.991762039 ProcessGroupNCCL.cpp:616] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600037 milliseconds before timing out.
ected all trees
sglang-head-r1:105:81719 [4] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 512 | 512
sglang-head-r1:105:81719 [4] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 2 p2p channels per peer
sglang-head-r1:105:81719 [4] NCCL INFO ncclCommSplit comm 0x7eefecbeb3b0 rank 4 nranks 16 cudaDev 4 nvmlDev 4 busId 69020 parent 0x564fe1225280 color 1197013201 key 4 commId 0x5e372fd9a36b237d - Init COMPLETE
[rank4]:[E223 11:14:34.991860392 ProcessGroupNCCL.cpp:1785] [PG ID 2 PG GUID 3 Rank 4] Exception (either an error or timeout) detected by watchdog at work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank4]:[E223 11:14:34.991866729 ProcessGroupNCCL.cpp:1834] [PG ID 2 PG GUID 3 Rank 4] Timeout at NCCL work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank4]:[E223 11:14:34.991871518 ProcessGroupNCCL.cpp:630] [Rank 4] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank4]:[E223 11:14:34.991874376 ProcessGroupNCCL.cpp:636] [Rank 4] To avoid data inconsistency, we are taking the entire process down.
[rank4]:[E223 11:14:34.993156163 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 4] Process group watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600037 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f0895d6c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f084bc2a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f084bc31bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f084bc3361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f08979125c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f089879aac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f089882c850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 2 PG GUID 3 Rank 4] Process group watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600037 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f0895d6c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f084bc2a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f084bc31bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f084bc3361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f08979125c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f089879aac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f089882c850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f0895d6c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7f084b8a071b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f08979125c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x94ac3 (0x7f089879aac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7f089882c850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007eef92de6640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 462 in watchdog_thread
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007eef91ce4640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 527 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 771 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 833 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 872 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 761 in forward_extend
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 796 in forward
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 164 in forward_batch_generation
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140 in forward_thread_func_
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 109 in forward_thread_func
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f0398fcd640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f0397fcb640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f08987054c0 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 320 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 171 in resolve_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1257 in process_batch_result_decode
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1116 in process_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 519 in event_loop_overlap
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1825 in run_scheduler_process
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108 in run
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, psutil._psutil_linux, psutil._psutil_posix, setproctitle, zmq.backend.cython._zmq, yaml._yaml, markupsafe._speedups, PIL._imaging, PIL._imagingft, msgspec._core, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, msgpack._cmsgpack, google._upb._message, ray._raylet, sentencepiece._sentencepiece, regex._regex, cuda_utils, __triton_launcher (total: 52)
s 8/8/64 | 128/8/64 | 512 | 512
sglang-head-r1:102:81761 [1] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 2 p2p channels per peer
[rank1]:[E223 11:14:34.002784259 ProcessGroupNCCL.cpp:616] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600048 milliseconds before timing out.
sglang-head-r1:102:81761 [1] NCCL INFO ncclCommSplit comm 0x7eebd8212bc0 rank 1 nranks 16 cudaDev 1 nvmlDev 1 busId 65030 parent 0x55e9a3326a40 color 1197013201 key 1 commId 0x5e372fd9a36b237d - Init COMPLETE
[rank1]:[E223 11:14:34.002852203 ProcessGroupNCCL.cpp:1785] [PG ID 2 PG GUID 3 Rank 1] Exception (either an error or timeout) detected by watchdog at work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank1]:[E223 11:14:34.002858978 ProcessGroupNCCL.cpp:1834] [PG ID 2 PG GUID 3 Rank 1] Timeout at NCCL work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank1]:[E223 11:14:34.002862501 ProcessGroupNCCL.cpp:630] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E223 11:14:34.002864978 ProcessGroupNCCL.cpp:636] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E223 11:14:34.004091302 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600048 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f047a56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f043042a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f0430431bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f043043361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f047c1f45c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f047d07cac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f047d10e850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 2 PG GUID 3 Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600048 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f047a56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f043042a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f0430431bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f043043361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f047c1f45c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f047d07cac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f047d10e850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f047a56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7f04300a071b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f047c1f45c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x94ac3 (0x7f047d07cac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7f047d10e850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007eeb7e5fc640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 462 in watchdog_thread
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007eeb7edfd640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 527 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 771 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 833 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 872 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 761 in forward_extend
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 796 in forward
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 164 in forward_batch_generation
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140 in forward_thread_func_
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 109 in forward_thread_func
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007eff787c8640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007eff777c6640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f047cfe74c0 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 320 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 171 in resolve_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1257 in process_batch_result_decode
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1116 in process_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 519 in event_loop_overlap
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1825 in run_scheduler_process
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108 in run
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, psutil._psutil_linux, psutil._psutil_posix, setproctitle, zmq.backend.cython._zmq, yaml._yaml, markupsafe._speedups, PIL._imaging, PIL._imagingft, msgspec._core, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, msgpack._cmsgpack, google._upb._message, ray._raylet, sentencepiece._sentencepiece, regex._regex, cuda_utils, __triton_launcher (total: 52)
[rank2]:[E223 11:14:34.006107122 ProcessGroupNCCL.cpp:616] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600052 milliseconds before timing out.
trees
sglang-head-r1:103:81765 [2] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 512 | 512
sglang-head-r1:103:81765 [2] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 2 p2p channels per peer
sglang-head-r1:103:81765 [2] NCCL INFO ncclCommSplit comm 0x7efa9ca4f9e0 rank 2 nranks 16 cudaDev 2 nvmlDev 2 busId 67020 parent 0x55adc52e7d90 color 1197013201 key 2 commId 0x5e372fd9a36b237d - Init COMPLETE
[rank2]:[E223 11:14:34.006211132 ProcessGroupNCCL.cpp:1785] [PG ID 2 PG GUID 3 Rank 2] Exception (either an error or timeout) detected by watchdog at work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank2]:[E223 11:14:34.006218798 ProcessGroupNCCL.cpp:1834] [PG ID 2 PG GUID 3 Rank 2] Timeout at NCCL work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank2]:[E223 11:14:34.006225792 ProcessGroupNCCL.cpp:630] [Rank 2] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank2]:[E223 11:14:34.006243662 ProcessGroupNCCL.cpp:636] [Rank 2] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E223 11:14:34.007493640 ProcessGroupNCCL.cpp:616] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600053 milliseconds before timing out.
nel 06/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 08/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 09/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 10/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 12/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 13/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 14/0 : 0[0] -> 1[1] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 03/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 05/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 07/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 11/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 13/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 15/0 : 0[0] -> 7[7] via P2P/IPC
sglang-head-r1:101:81764 [0] NCCL INFO Channel 00/0 : 8[0] -> 0[0] [receive] via NET/IBext_v8/0
sglang-head-r1:101:81764 [0] NCCL INFO Channel 01/0 : 8[0] -> 0[0] [receive] via NET/IBext_v8/2
sglang-head-r1:101:81764 [0] NCCL INFO Channel 08/0 : 8[0] -> 0[0] [receive] via NET/IBext_v8/0
sglang-head-r1:101:81764 [0] NCCL INFO Channel 09/0 : 8[0] -> 0[0] [receive] via NET/IBext_v8/2
sglang-head-r1:101:81764 [0] NCCL INFO Channel 00/0 : 0[0] -> 8[0] [send] via NET/IBext_v8/0
sglang-head-r1:101:81764 [0] NCCL INFO Channel 01/0 : 0[0] -> 8[0] [send] via NET/IBext_v8/2
sglang-head-r1:101:81764 [0] NCCL INFO Channel 08/0 : 0[0] -> 8[0] [send] via NET/IBext_v8/0
sglang-head-r1:101:81764 [0] NCCL INFO Channel 09/0 : 0[0] -> 8[0] [send] via NET/IBext_v8/2
sglang-head-r1:101:81764 [0] NCCL INFO Connected all trees
sglang-head-r1:101:81764 [0] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 512 | 512
sglang-head-r1:101:81764 [0] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 2 p2p channels per peer
sglang-head-r1:101:81764 [0] NCCL INFO ncclCommSplit comm 0x7f9e2cb19d20 rank 0 nranks 16 cudaDev 0 nvmlDev 0 busId 65020 parent 0x55f7f715de20 color 1197013201 key 0 commId 0x5e372fd9a36b237d - Init COMPLETE
[rank0]:[E223 11:14:34.007572619 ProcessGroupNCCL.cpp:1785] [PG ID 2 PG GUID 3 Rank 0] Exception (either an error or timeout) detected by watchdog at work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank0]:[E223 11:14:34.007581928 ProcessGroupNCCL.cpp:1834] [PG ID 2 PG GUID 3 Rank 0] Timeout at NCCL work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank0]:[E223 11:14:34.007586917 ProcessGroupNCCL.cpp:630] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E223 11:14:34.007590240 ProcessGroupNCCL.cpp:636] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank2]:[E223 11:14:34.007651635 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 2] Process group watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600052 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f164c56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f160242a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f1602431bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f160243361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f164e16c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f164eff4ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f164f086850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 2 PG GUID 3 Rank 2] Process group watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600052 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f164c56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f160242a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f1602431bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f160243361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f164e16c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f164eff4ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f164f086850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f164c56c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7f16020a071b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f164e16c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x94ac3 (0x7f164eff4ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7f164f086850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007efd936fb640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 462 in watchdog_thread
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007efd93ffc640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 527 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 771 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 833 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 872 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 761 in forward_extend
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 796 in forward
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 164 in forward_batch_generation
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140 in forward_thread_func_
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 109 in forward_thread_func
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f114afc5640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f1149fc3640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f164ef5f4c0 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 320 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 171 in resolve_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1257 in process_batch_result_decode
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1116 in process_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 519 in event_loop_overlap
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1825 in run_scheduler_process
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108 in run
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested[rank0]:[E223 11:14:34.009113793 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600053 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fb72e96c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7fb6e482a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7fb6e4831bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7fb6e483361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7fb73060c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7fb731494ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7fb731526850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

, torch._C._nn, torch._C._sparse, terminate called after throwing an instance of 'torch._C._specialc10::DistBackendError'
, psutil._psutil_linux, psutil._psutil_posix, setproctitle, zmq.backend.cython._zmq, yaml._yaml, markupsafe._speedups, PIL._imaging, PIL._imagingft, msgspec._core, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, msgpack._cmsgpack, google._upb._message, ray._raylet, sentencepiece._sentencepiece, regex._regex, cuda_utils, __triton_launcher  what():  [PG ID 2 PG GUID 3 Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600053 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fb72e96c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7fb6e482a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7fb6e4831bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7fb6e483361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7fb73060c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7fb731494ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7fb731526850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fb72e96c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7fb6e44a071b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7fb73060c5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x94ac3 (0x7fb731494ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7fb731526850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007f9e5b7f3640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 462 in watchdog_thread
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f9e5bff4640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 527 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File  (total: "52/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py)"
, line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 771 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 833 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 872 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 761 in forward_extend
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 796 in forward
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 164 in forward_batch_generation
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140 in forward_thread_func_
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 109 in forward_thread_func
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007fb218fc5640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007fb2187c4640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007fb7313ff4c0 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 320 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 171 in resolve_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1257 in process_batch_result_decode
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1116 in process_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 519 in event_loop_overlap
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1825 in run_scheduler_process
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108 in run
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, psutil._psutil_linux, psutil._psutil_posix, setproctitle, zmq.backend.cython._zmq, yaml._yaml, markupsafe._speedups, PIL._imaging, PIL._imagingft, msgspec._core, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, msgpack._cmsgpack, google._upb._message, ray._raylet, sentencepiece._sentencepiece, regex._regex, cuda_utils, __triton_launcher (total: 52)
[rank7]:[E223 11:14:34.011293427 ProcessGroupNCCL.cpp:616] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600057 milliseconds before timing out.
ected all trees
sglang-head-r1:108:81763 [7] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 512 | 512
sglang-head-r1:108:81763 [7] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 2 p2p channels per peer
sglang-head-r1:108:81763 [7] NCCL INFO ncclCommSplit comm 0x7f553471c140 rank 7 nranks 16 cudaDev 7 nvmlDev 7 busId 6b030 parent 0x564182ff0cb0 color 1197013201 key 7 commId 0x5e372fd9a36b237d - Init COMPLETE
[rank7]:[E223 11:14:34.011376745 ProcessGroupNCCL.cpp:1785] [PG ID 2 PG GUID 3 Rank 7] Exception (either an error or timeout) detected by watchdog at work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank7]:[E223 11:14:34.011381427 ProcessGroupNCCL.cpp:1834] [PG ID 2 PG GUID 3 Rank 7] Timeout at NCCL work: 91885, last enqueued NCCL work: 91885, last completed NCCL work: 91884.
[rank7]:[E223 11:14:34.011384560 ProcessGroupNCCL.cpp:630] [Rank 7] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank7]:[E223 11:14:34.011386790 ProcessGroupNCCL.cpp:636] [Rank 7] To avoid data inconsistency, we are taking the entire process down.
[rank7]:[E223 11:14:34.012705754 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 7] Process group watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600057 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f6f7916c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f6f2f02a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f6f2f031bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f6f2f03361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f6f7ae1d5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f6f7bca5ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f6f7bd37850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 2 PG GUID 3 Rank 7] Process group watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=91885, OpType=ALLREDUCE, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000) ran for 600057 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f6f7916c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f6f2f02a772 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f6f2f031bb3 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f6f2f03361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f6f7ae1d5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f6f7bca5ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f6f7bd37850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f6f7916c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x7f6f2eca071b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f6f7ae1d5c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x94ac3 (0x7f6f7bca5ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7f6f7bd37850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007f56b81e9640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 462 in watchdog_thread
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f56b72e7640 (most recent call first):
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 527 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 771 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 833 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 872 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 761 in forward_extend
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 796 in forward
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 164 in forward_batch_generation
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140 in forward_thread_func_
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 109 in forward_thread_func
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f6a7c7c8640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f6a7b7c6640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f6f7bc104c0 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 320 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 171 in resolve_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1257 in process_batch_result_decode
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1116 in process_batch_result
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 519 in event_loop_overlap
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1825 in run_scheduler_process
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108 in run
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
  File "/usr/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: (same 52-module list as above)

[Ranks 6, 3, and 5 abort the same way: the watchdog catches the identical ALLREDUCE timeout (SeqNum=91885, NumelIn=1290240, NumelOut=1290240, Timeout(ms)=600000), prints the same C++ backtrace, the same Python thread dumps (forward thread in deepseek_v2.py forward → model_runner.forward_extend → tp_worker_overlap_thread.forward_thread_func), and the same extension-module list, and each process is taken down. The only notable difference is that rank 6's scheduler main thread is blocked in broadcast_pyobj (sglang/srt/utils.py line 693, called from Scheduler.recv_requests) rather than in resolve_batch_result.]

@nvcastet nvcastet force-pushed the fix_all_gather_cuda_graph branch from 48a926f to 38585e0 on February 24, 2025 20:26
@nvcastet
Copy link
Author

@yizhang2077 @zhyncs I went back to the first commit and registered the pynccl all-gather as a PyTorch custom op, as you suggested.
Ideally, it would be nice to get rid of pynccl if it is no longer needed, but that can be done in another PR.
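
For readers hitting the same problem, here is a minimal sketch of the general technique being described: wrapping a collective in a PyTorch custom op so that torch.compile and CUDA graph capture treat it as one opaque call and the backend can be chosen inside it at run time. Everything below is illustrative, not the PR's code: the sglang_sketch namespace and the stand-in body that uses the default process group are assumptions, and torch.library.custom_op requires PyTorch >= 2.4. The actual fix dispatches to sglang's communicator abstraction (pynccl or the custom kernels) instead of torch.distributed.

import torch
import torch.distributed as dist

# Hypothetical op name and stand-in body. The real fix routes through
# sglang's tensor-parallel abstraction so that eager mode and graph capture
# do not end up sharing the same NCCL communicator.
@torch.library.custom_op("sglang_sketch::all_gather", mutates_args=())
def sketch_all_gather(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Gather along dim 0 with the default process group (stand-in only).
    out = x.new_empty((world_size * x.shape[0], *x.shape[1:]))
    dist.all_gather_into_tensor(out, x)
    return out

@sketch_all_gather.register_fake
def _(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Shape-only kernel so fake-tensor tracing / compilation never touches NCCL.
    return x.new_empty((world_size * x.shape[0], *x.shape[1:]))

Call sites would then invoke sketch_all_gather(hidden_states, tp_size) directly; because the op is opaque to the compiler and the capture machinery, the implementation picked inside it can differ between eager execution and graph replay.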

@yizhang2077 yizhang2077 self-requested a review February 25, 2025 01:54
Copy link
Collaborator

@yizhang2077 yizhang2077 left a comment

LGTM. Hi @nvcastet, could you format your code (CI failed on lint)? @ispobock @zhyncs, could you take a look?

@robscc
Copy link

robscc commented Feb 25, 2025

@yizhang2077 @zhyncs I went back to the first commit and registered the pynccl all-gather as a PyTorch custom op, as you suggested. Ideally, it would be nice to get rid of pynccl if it is no longer needed, but that can be done in another PR.

Could you do me a favor? I merged your PR and the issue still reproduces. Many thanks.

@ispobock
Copy link
Collaborator

@ispobock when downloading nvidia-nccl-cu11, I see cu116:

# pip download nvidia-nccl-cu11
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com/
Collecting nvidia-nccl-cu11
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-nccl-cu11/nvidia-nccl-cu11-2022.5.19.tar.gz (16 kB)
  Preparing metadata (setup.py) ... done
Collecting nvidia-nccl-cu116 (from nvidia-nccl-cu11)
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-nccl-cu116/nvidia_nccl_cu116-2.12.12-py3-none-manylinux1_x86_64.whl (164.8 MB)

So we should be good, right?

In my environment, it downloads the wheel from https://download.pytorch.org/whl/nvidia-nccl-cu11/ when installing PyTorch, so it is cu11, not cu116.
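
As a quick check (an editor-added suggestion, not something run in this thread), the NCCL build that PyTorch actually loads can be printed from Python, independent of which nvidia-nccl-cu11* wheel pip resolved:

import torch

print(torch.__version__)
print(torch.cuda.nccl.version())            # NCCL version PyTorch links against, e.g. (2, 21, 5)
print(torch.distributed.is_nccl_available())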

@ispobock
Copy link
Collaborator

ispobock commented Feb 25, 2025

@nvcastet Could you fix the lint with pre-commit run --all-files?

Successfully merging this pull request may close these issues.

[Bug] DeepSeek R1 serve crash occasionally on 2*H100
7 participants