diff --git a/aiter/ops/triton/moe_op.py b/aiter/ops/triton/moe_op.py
new file mode 100644
index 0000000..fa009fd
--- /dev/null
+++ b/aiter/ops/triton/moe_op.py
@@ -0,0 +1,278 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import triton
+import triton.language as tl
+from typing import Any, Dict, Optional, List
+
+
+# TODO: add support to set this through an env var or a function call by a framework like vLLM/SGLang
+# MOE_PADDING_SIZE = 128 if AITER_TRITON_MOE_PADDING else 0
+MOE_PADDING_SIZE = 0
+
+@triton.jit
+def _fused_moe_kernel(
+        # Pointers to matrices
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        a_scale_ptr,
+        b_scale_ptr,
+        topk_weights_ptr,
+        sorted_token_ids_ptr,
+        expert_ids_ptr,
+        num_tokens_post_padded_ptr,
+        # Matrix dimensions
+        N,
+        K,
+        EM,
+        num_valid_tokens,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_am,
+        stride_ak,
+        stride_be,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Block size for block-wise quantization
+        group_n: tl.constexpr,
+        group_k: tl.constexpr,
+        # Meta-parameters
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+        MUL_ROUTED_WEIGHT: tl.constexpr,
+        top_k: tl.constexpr,
+        compute_type: tl.constexpr,
+        use_fp8_w8a8: tl.constexpr,
+        use_int8_w8a16: tl.constexpr):
+    """
+    Note: this is a Triton JIT-compiled function and is not meant to be called
+    directly. Call the `moe_triton` wrapper defined below instead.
+
+    Implements the fused computation for a Mixture of Experts (MoE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+      be any shape representing batches and K is the feature dimension of
+      each token.
+    - B: The stacked MoE weight tensor with shape (E, N, K), where E is
+      the number of experts, K is the input feature dimension, and N is
+      the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+      total number of tokens post padding, topk is the number of times
+      each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+      repeated topk times and arranged by the expert index they are
+      assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+      block. It determines which expert matrix from B should be used for
+      each block in A.
+
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and the padding ensure divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance these pointers as we move in the K direction
+    # and accumulate.
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)
+    if use_int8_w8a16:
+        b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
+                        offs_bn[None, :] * stride_bsn)
+        b_scale = tl.load(b_scale_ptrs)
+
+    if use_fp8_w8a8:
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
+                            offs_bsn * stride_bsn)
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` is converted back to `compute_type` after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
+        # We accumulate along the K dimension.
+        if use_int8_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_fp8_w8a8:
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
+                                  mask=token_mask,
+                                  other=0.0)
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
+        accumulator = accumulator * moe_weight[:, None]
+    if use_int8_w8a16:
+        accumulator = (accumulator * b_scale).to(compute_type)
+    elif use_fp8_w8a8:
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def moe_triton(A: torch.Tensor,
+               B: torch.Tensor,
+               C: torch.Tensor,
+               A_scale: Optional[torch.Tensor],
+               B_scale: Optional[torch.Tensor],
+               topk_weights: torch.Tensor,
+               topk_ids: torch.Tensor,
+               sorted_token_ids: torch.Tensor,
+               expert_ids: torch.Tensor,
+               num_tokens_post_padded: torch.Tensor,
+               mul_routed_weight: bool,
+               top_k: int,
+               config: Dict[str, Any],
+               compute_type: tl.dtype,
+               use_fp8_w8a8: bool,
+               use_int8_w8a16: bool,
+               block_shape: Optional[List[int]] = None) -> None:
+    if mul_routed_weight:
+        assert topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
+    if use_fp8_w8a8:
+        # NOTE: `ops.scaled_fp8_quant` and `per_token_group_quant_fp8` are not
+        # defined or imported in this module; the fp8 path expects the calling
+        # framework to provide them.
+        assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+    elif use_int8_w8a16:
+        assert B_scale is not None
+    else:
+        assert A_scale is None
+        assert B_scale is None
+
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META['BLOCK_SIZE_M']) *
+                         triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+
+    _fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        A_scale,
+        B_scale,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.shape[1],
+        B.shape[2] - MOE_PADDING_SIZE,
+        sorted_token_ids.shape[0],
+        topk_ids.numel(),
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=compute_type,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        **config,
+    )
+
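For reference, a minimal sketch of how the `moe_triton` wrapper above might be driven end to end (bf16 path, no fp8). The routing/alignment helpers are the ones defined in op_tests/triton/test_moe.py added below; their importability and the chosen shapes are assumptions, not part of the kernel module:

    import torch
    import triton.language as tl

    from aiter.ops.triton.moe_op import moe_triton
    # Helpers from op_tests/triton/test_moe.py (added below); assumed importable here.
    from op_tests.triton.test_moe import get_default_config, moe_align_block_size

    M, N, K, top_k, E = 16, 128, 256, 2, 4
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype, device="cuda")         # token activations
    b = torch.randn((E, N, K), dtype=dtype, device="cuda")      # stacked expert weights
    c = torch.zeros((M, top_k, N), dtype=dtype, device="cuda")  # output buffer

    # Pick top-k experts per token from a softmax over random router logits.
    scores = torch.softmax(torch.randn(M, E, dtype=dtype, device="cuda"), dim=1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1)

    config = get_default_config()
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E)

    moe_triton(a, b, c, None, None, topk_weights, topk_ids, sorted_token_ids,
               expert_ids, num_tokens_post_padded, mul_routed_weight=True,
               top_k=top_k, config=config, compute_type=tl.bfloat16,
               use_fp8_w8a8=False, use_int8_w8a16=False)
    # c[m, i, :] now holds token m's output for its i-th routed expert,
    # already scaled by the routing weight.
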
diff --git a/op_tests/triton/test_moe.py b/op_tests/triton/test_moe.py
new file mode 100644
index 0000000..54a0d6e
--- /dev/null
+++ b/op_tests/triton/test_moe.py
@@ -0,0 +1,196 @@
+import triton
+import torch
+import triton.language as tl
+import pytest
+from typing import Any, Dict, Optional
+import os
+import json
+import functools
+import argparse
+import sys
+
+from aiter.ops.triton.moe_op import moe_triton
+
+
+def torch_moe(a, b, c, topk_ids, topk_weights, routed_weight, sorted_token_ids,
+              expert_ids, num_tokens_post_padded, a_scale, b_scale, dtype, fp8):
+    E, N, K = b.shape
+    M, topk, _ = c.shape
+    c = c.reshape(-1, c.shape[2])
+
+    if fp8:
+        a = a.to(dtype)
+
+    for e in range(E):
+        token_ids = (topk_ids == e).any(dim=-1)
+        flat_topk_ids = topk_ids.view(-1)
+        flat_token_ids = torch.arange(topk_ids.numel(), device=topk_ids.device)
+        c_token_ids = flat_token_ids[flat_topk_ids == e]
+
+        b_e = b[e]
+        a_e = a[token_ids, :]
+
+        if fp8:
+            b_e = b_e.to(dtype)
+
+        acc = torch.matmul(a_e, b_e.T)
+        if routed_weight:
+            acc = acc * topk_weights.view(-1)[c_token_ids].unsqueeze(-1)
+
+        if fp8:
+            acc = (acc * a_scale * b_scale[e]).to(dtype)
+
+        c[c_token_ids, :] = acc
+
+    c = c.reshape(M, topk, N)
+
+    return c
+
+
+def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int,
+                          block_size: int, sorted_token_ids: torch.Tensor,
+                          expert_ids: torch.Tensor,
+                          num_tokens_post_pad: torch.Tensor) -> None:
+    M, top_k = topk_ids.shape
+
+    expert_to_tokens = [[] for _ in range(num_experts)]
+    # For each token and each of its selected experts, append the flattened
+    # position (token_id * top_k + j) to that expert's list.
+    for token_id in range(M):
+        for j in range(top_k):
+            e_id = topk_ids[token_id, j].item()
+            expert_to_tokens[e_id].append(token_id * top_k + j)
+
+    # Reorder tokens block by block, padding if needed
+    reordered_token_ids = []
+    reordered_expert_ids = []
+
+    for e_id in range(num_experts):
+        tokens_for_expert = expert_to_tokens[e_id]
+        num_tokens = len(tokens_for_expert)
+
+        n_blocks = (num_tokens + block_size - 1) // block_size
+        # If not a multiple of block_size, pad up to the next multiple
+        padded_size = n_blocks * block_size
+
+        # Reorder all actual tokens for expert e_id
+        reordered_token_ids.extend(tokens_for_expert)
+        # One expert id per block (not per token)
+        reordered_expert_ids.extend([e_id] * n_blocks)
+
+        # Pad with dummy token_id = topk_ids.numel()
+        if padded_size > num_tokens:
+            pad_count = padded_size - num_tokens
+            reordered_token_ids.extend([topk_ids.numel()] * pad_count)
+
+    token_length = len(reordered_token_ids)
+    expert_length = len(reordered_expert_ids)
+
+    sorted_token_ids[:token_length] = torch.tensor(reordered_token_ids,
+                                                   dtype=sorted_token_ids.dtype,
+                                                   device=sorted_token_ids.device)
+    expert_ids[:expert_length] = torch.tensor(reordered_expert_ids,
+                                              dtype=expert_ids.dtype,
+                                              device=expert_ids.device)
+
+    # Fill the remainder with topk_ids.numel() if these arrays are longer than
+    # token_length / expert_length.
+    if token_length < sorted_token_ids.numel():
+        sorted_token_ids[token_length:] = topk_ids.numel()
+    if expert_length < expert_ids.numel():
+        expert_ids[expert_length:] = topk_ids.numel()
+
+    num_tokens_post_pad.fill_(token_length)
+
+
+def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
+                         num_experts: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with the block size used for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 5 (only experts 1-4 are selected here):
+    - We initially have 12 tokens (after repeating 'top_k' times), and each of the four selected experts needs to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each of these experts.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append one padding token per expert so that each expert's block is full.
+    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Token id 12 is a non-existent (padding) token and is ignored in the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
+    """
+    top_k = topk_ids.shape[1]
+    sorted_ids = torch.empty((topk_ids.numel() + num_experts * (block_size - 1), ),
+                             dtype=torch.int32, device=topk_ids.device)
+    expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32,
+                             device=topk_ids.device)
+    sorted_ids.fill_(topk_ids.numel())
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    _moe_align_block_size(topk_ids, num_experts, top_k, block_size, sorted_ids,
+                          expert_ids, num_tokens_post_pad)
+
+    return sorted_ids, expert_ids, num_tokens_post_pad
+
+
+def get_default_config() -> Dict[str, int]:
+    config = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
+    return config
+
+
+def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, compute_type, fp8: bool):
+    if fp8:
+        a = torch.randn((M, K), dtype=compute_type, device='cuda')
+        a = a.to(torch.float8_e4m3fnuz)
+        b = torch.rand((E, N, K), dtype=compute_type, device='cuda')
+        b = b.to(torch.float8_e4m3fnuz)
+    else:
+        b = torch.randn((E, N, K), dtype=compute_type, device='cuda')
+        a = torch.randn((M, K), dtype=compute_type, device='cuda')
+    c = torch.zeros((M, top_k, N), dtype=compute_type, device='cuda')
+
+    if fp8:
+        a_scale = torch.randn((1), dtype=torch.float32, device='cuda')
+        b_scale = torch.randn((E), dtype=torch.float32, device='cuda')
+    else:
+        a_scale = None
+        b_scale = None
+
+    values = torch.randn(M, E, dtype=compute_type, device='cuda')
+
+    softmax_vals = torch.softmax(values, dim=1)
+    topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)
+
+    config = get_default_config()
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E)
+
+    if not routed_weight:
+        return a, b, c, None, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config, a_scale, b_scale
+
+    return a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config, a_scale, b_scale
+
+
+torch_to_tl_dtype = {torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32}
+
+
+@pytest.mark.parametrize("M, N, K, top_k, E", [(64, 14336, 4096, 2, 8), (16, 14336, 1, 2, 4), (4, 4, 8, 1, 2),
+                                               (1, 14336, 128, 2, 4), (3, 14336, 128, 2, 4), (16, 14336, 128, 1, 4),
+                                               (16, 14336, 128, 1, 1), (64, 7186, 128, 2, 8), (64, 3584, 128, 2, 8),
+                                               (64, 1792, 128, 2, 8), (64, 64, 128, 2, 8), (1, 1024, 16384, 1, 2)])
+@pytest.mark.parametrize('routed_weight', [True, False])
+@pytest.mark.parametrize('fp8', [False])  # TODO: add support for fp8
+def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, fp8: bool, compute_type=torch.bfloat16):
+    # torch.manual_seed(20)
+    a, b, triton_out, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config, a_scale, b_scale = input_helper(
+        M, N, K, top_k, E, routed_weight=routed_weight, compute_type=compute_type, fp8=fp8)
+
+    print(f"sorted_token_ids={sorted_token_ids}")
+    print(f"expert_ids={expert_ids}")
+    print(f"num_tokens_post_padded={num_tokens_post_padded}")
+    moe_triton(a, b, triton_out, a_scale, b_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids,
+               num_tokens_post_padded, routed_weight, top_k, config, torch_to_tl_dtype[compute_type], fp8, False)
+
+    torch_out = torch.empty_like(triton_out)
+    torch_out = torch_moe(a, b, torch_out, topk_ids, topk_weights, routed_weight, sorted_token_ids, expert_ids,
+                          num_tokens_post_padded, a_scale, b_scale, compute_type, fp8)
+
+    # Validate correctness
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2)
\ No newline at end of file
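A minimal sketch reproducing the worked example from the moe_align_block_size docstring above (assumes a CUDA device and that the test module is importable; the shapes are hypothetical):

    import torch

    from op_tests.triton.test_moe import moe_align_block_size

    topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
                            dtype=torch.int32, device="cuda")
    sorted_ids, expert_ids, num_post_pad = moe_align_block_size(topk_ids,
                                                                block_size=4,
                                                                num_experts=5)
    # Expected, per the docstring: the first 16 entries of sorted_ids are
    # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12] (12 marks padding),
    # expert_ids starts with [1, 2, 3, 4] (one entry per block), and
    # num_post_pad is 16.
    print(sorted_ids[:16], expert_ids[:4], num_post_pad)
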
diff --git a/requirements.txt b/requirements.txt
index 3f9439a..3800156 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
 pandas<=2.0.3
+pytest
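With pytest added to requirements.txt, the new correctness test can be run from the repository root, for example programmatically (a sketch equivalent to invoking `pytest op_tests/triton/test_moe.py -v`):

    import pytest

    # Run only the new MoE test file; returns a non-zero exit code on failure.
    raise SystemExit(pytest.main(["op_tests/triton/test_moe.py", "-v"]))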