
Add multi-nodes example & update doc (#455)
Documentation update:

* [`docs/design/mscclpp-dsl.md`](diffhunk://#diff-02a69290fb3e02b8a069bf915fbf5266cfc2ac51c6e9ff8b5b19df51ed909b22L114-R114): Updated the link to the examples folder to reflect the correct path.

New example script:

* [`python/examples/allgather_allpairs_multinodes_packets.py`](diffhunk://#diff-ab42c16ecca0680d55b60b82a6913138c5fba4069b9c4493fbe8c72217fe54bcR1-R76): Added a new example script demonstrating the allgather all-pairs algorithm across multiple nodes using packet communication.

IR module improvements:

* [`python/mscclpp/language/ir.py`](diffhunk://#diff-b025796b03fbbd9b2ca9aee2569547efa7a56101743bc4aa05661be0b52aeec9L470-R472): Refined the sorting criteria for GPU instance channels and thread block channels to include the channel type, ensuring a stable, deterministic ordering.

Debugging enhancements:

* [`src/executor/executor.cc`](diffhunk://#diff-60f7806d111e5cc12ded06358b5d5b09b8521e3858f182d8be81ac05147c535dR439-R441): Added a debug log to indicate the start of communication collective execution, with details about the execution plan and collective.
* [`src/include/debug.h`](diffhunk://#diff-24e5fda55e3712277be4bb99b3c348294a77ebd3046bfe716b74bdb32cd203dfR89): Introduced a new debug log subsystem identifier `MSCCLPP_EXECUTOR` for logging executor-related information.
Binyang2014 authored Feb 1, 2025
1 parent 3565bfd commit 7f3b088
Showing 8 changed files with 91 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/mscclpp-lang.yml
@@ -19,6 +19,10 @@ jobs:

    steps:
      - uses: actions/checkout@v4

+      - name: Set environment variable
+        run: echo "LD_LIBRARY_PATH=/usr/local/cuda/compat:/usr/local/cuda/lib64" >> $GITHUB_ENV

      - name: Install mscclpp
        run: |
          CMAKE_ARGS="-DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON" pip3 install .
1 change: 0 additions & 1 deletion docker/build.sh
@@ -14,7 +14,6 @@ baseImageTable=(

declare -A extraLdPathTable
extraLdPathTable=(
-    ["cuda11.8"]="/usr/local/cuda-11.8/lib64"
    ["cuda12.1"]="/usr/local/cuda-12.1/compat:/usr/local/cuda-12.1/lib64"
    ["cuda12.2"]="/usr/local/cuda-12.2/compat:/usr/local/cuda-12.2/lib64"
    ["cuda12.3"]="/usr/local/cuda-12.3/compat:/usr/local/cuda-12.3/lib64"
2 changes: 1 addition & 1 deletion docs/design/mscclpp-dsl.md
@@ -111,4 +111,4 @@ Packet APIs are used when user wants to use LL algorithm. The packet APIs are si

### Examples
-We provide several examples demonstrating how to use the MSCCL++ DSL to write communication collective algorithms. For more details, please refer to the [examples](https://github.com/microsoft/mscclpp/tree/main/mscclpp-lang/python/examples) folder.
+We provide several examples demonstrating how to use the MSCCL++ DSL to write communication collective algorithms. For more details, please refer to the [examples](https://github.com/microsoft/mscclpp/tree/main/python/examples) folder.
74 changes: 74 additions & 0 deletions python/examples/allgather_allpairs_multinodes_packets.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllGather
from mscclpp.language.buffer import Buffer
from mscclpp.language.types import ChannelType, ReplicationPolicy


def allgather_multinodes_allpair(gpus, gpus_per_node, instances):
    """
    Implements a multi-node allgather collective using an all-pairs algorithm with the MSCCL++ DSL.
    @param gpus: Total number of GPUs
    @param gpus_per_node: Number of GPUs per node
    @param instances: Number of instances
    Steps:
    1. Each rank sends a chunk to all other ranks' scratch buffers in packet format.
    2. Each rank copies the chunk from its scratch buffer to the output buffer in packet format.
    """
    collective = AllGather(gpus, 1, True)
    with MSCCLPPProgram(
        "allgather_multinodes_allpair",
        collective,
        gpus,
        instances,
        protocol="LL",
        replication_policy=ReplicationPolicy.interleaved,
        num_threads_per_block=1024,
    ):
        for g in range(gpus):
            src_rank = g
            c = chunk(src_rank, Buffer.input, 0, 1)
            for peer in range(1, gpus):
                dst_rank = (src_rank + peer) % gpus
                tb = dst_rank if dst_rank < src_rank else dst_rank - 1
                if src_rank // gpus_per_node == dst_rank // gpus_per_node:
                    c.put_packet(dst_rank, Buffer.scratch, index=src_rank, sendtb=tb)
                else:
                    c.put_packet(
                        dst_rank,
                        Buffer.scratch,
                        index=src_rank,
                        sendtb=tb,
                        chan_type=ChannelType.port,
                        temp_buffer=Buffer.scratch,
                        temp_buffer_index=src_rank,
                    )

        # Copy the packet from the local scratch buffer to the local output buffer
        for g in range(gpus):
            src_rank = g
            src_offset = src_rank
            for peer in range(1, gpus):
                dst_rank = (g + peer) % gpus
                tb = src_offset if src_offset < dst_rank else src_offset - 1
                c = chunk(dst_rank, Buffer.scratch, src_offset, 1)
                c.copy_packet(dst_rank, Buffer.output, src_offset, sendtb=tb + gpus - 1)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("gpus_per_node", type=int, help="number of gpus per node")
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

allgather_multinodes_allpair(
    args.num_gpus,
    args.gpus_per_node,
    args.instances,
)
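
A plausible invocation, mirroring the arguments in the test config added below, is `python3 python/examples/allgather_allpairs_multinodes_packets.py 16 8 1` (total GPUs, GPUs per node, instances); assuming `Json()` emits the lowered execution plan as JSON on stdout, the output can be redirected to a plan file for the executor.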
6 changes: 4 additions & 2 deletions python/mscclpp/language/ir.py
@@ -467,7 +467,9 @@ def remove_empty_fields(d):
                obj["connectedTo"] = [sorted(list(peers)) for peers in obj["connectedTo"]]
                gpu_instance["channels"].append(obj)
            gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"]))
-            gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"]))
+            gpu_instance["channels"] = sorted(
+                gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"], x["type"])
+            )

            # render for GPU NVLS channels
            for i, chan in enumerate(gpu_instance["channels"]):

@@ -502,7 +504,7 @@ def remove_empty_fields(d):
                    tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj
                    tb_channels.append(obj)
                tb_channels = filter(lambda x: x["type"] != "none", tb_channels)
-                tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"]))
+                tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"], x["type"]))
                for op in tb.ops:
                    if op.tb == -1:
                        continue
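
The effect of the extra sort key, as a minimal standalone sketch (the dict values below are illustrative stand-ins, not actual IR channel objects): channels that tie on `srcbuff` and `dstbuff` but differ in `type` now land in a stable order.

channels = [
    {"srcbuff": "i", "dstbuff": "s", "type": "port"},
    {"srcbuff": "i", "dstbuff": "s", "type": "memory"},
]
# With the old key (srcbuff, dstbuff) the two entries tie, so their relative
# order depends on input order; adding "type" breaks the tie deterministically.
print(sorted(channels, key=lambda x: (x["srcbuff"], x["dstbuff"], x["type"])))
# -> the "memory" channel now sorts before "port" regardless of input order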
4 changes: 4 additions & 0 deletions python/test/configs/mscclpp_lang_test_config.json
@@ -30,5 +30,9 @@
  {
    "filename": "allreduce_nvls.py",
    "args": ["8", "2"]
+  },
+  {
+    "filename": "allgather_allpairs_multinodes_packets.py",
+    "args": ["16", "8", "1"]
  }
]
3 changes: 3 additions & 0 deletions src/executor/executor.cc
@@ -7,6 +7,7 @@
#include <mscclpp/port_channel.hpp>
#include <set>

+#include "debug.h"
#include "execution_kernel.hpp"
#include "execution_plan.hpp"

@@ -435,6 +436,8 @@ Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
                       [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
                       cudaStream_t stream, PacketType packetType) {
+  INFO(MSCCLPP_EXECUTOR, "Starting execution with plan: %s, collective: %s", plan.name().c_str(),
+       plan.collective().c_str());
  size_t sendMemRange, recvMemRange;
  CUdeviceptr sendBasePtr, recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff));
1 change: 1 addition & 0 deletions src/include/debug.h
@@ -86,6 +86,7 @@ typedef enum {
  MSCCLPP_ENV = 128,
  MSCCLPP_ALLOC = 256,
  MSCCLPP_CALL = 512,
+ MSCCLPP_EXECUTOR = 1024,
  MSCCLPP_ALL = ~0
} mscclppDebugLogSubSys;

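At runtime, mscclpp gates these messages with the NCCL-style `MSCCLPP_DEBUG` and `MSCCLPP_DEBUG_SUBSYS` environment variables; assuming the subsystem parser accepts an `EXECUTOR` name for the new enum value (the commit shows only the enum itself), something like `MSCCLPP_DEBUG=INFO MSCCLPP_DEBUG_SUBSYS=EXECUTOR` would surface the new executor log.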
