diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int2_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int2_bf16_gemm.cu
new file mode 100644
index 0000000000..5346da62f8
--- /dev/null
+++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int2_bf16_gemm.cu
@@ -0,0 +1,668 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
+
+    This example shows how to perform an INT2 x BF16 GEMM and scale the INT2 weights during dequantization.
+
+    The narrower type always passes through the register file, so when the narrower type is operand B, the collective implicitly swaps
+    A and B in the main loop. Because of this implicit swap, the collective does not support TMA epilogues; this must be taken into
+    account when constructing the epilogue, as illustrated in this example.
+
+    Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
+
+    As an additional optimization, we can reorder the narrow data type tensor such that elements read into the register file by the same thread are contiguous in global and shared memory.
+    This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when the MMA is performed in the FP8 data type, each thread reads
+    4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for the thread-value layout).
+    If the narrow type is INT4 and the tensor is major in the K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
+    If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
+    tensor is static (e.g. the weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
+    OPTIMIZE_WEIGHT_LAYOUT is set to 1.
+
+    The scale's K dimension is expected to be scale_k = ceil_div(problem_k, group_size).
+
+    Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled, or N if B is scaled.
+
+    If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, they must have shape [N, scale_k].
+
+    The implementation only supports "group-wise" scales. However, per-column scales can be obtained by setting the group size
+    equal to the GEMM problem K.
+
+    Limitations:
+      1) Only supports INT2 x BF16. The scale type must be the same as the MMA type. Scale-with-zero-point mode is not supported.
+      2) The INT2 weights have additional encoding requirements.
+      3) The scales must be MN major: if A is scaled, they must be column major; if B is scaled, they must be row major.
+      4) The scales must have the same layout and group size.
+      5) The group size must be greater than or equal to the tile shape K.
+      6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
+         operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
+         We plan to address this limitation in the future. In this example, we work around it by explicitly swapping and transposing the operands.
+
+    Optimization suggestions:
+      1) Use a small tile size, since register pressure for this GEMM (and RS GEMM in general) is high.
+
+    Examples:
+
+      Runs the mixed input batched gemm (with batch size 2), applying group-wise scales to B (default group size 128)
+      $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int2_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2
+
+      Runs the mixed input gemm, applying a single vector of scales to the entire
+      matrix (group size is the same as the gemm k dimension).
+      $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int2_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/tensor_fill.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+
+#include "helper.h"
+#include "unfused_weight_dequantize.hpp"
+#include "packed_scale.hpp"
+#include "reorder_utils.hpp"
+
+using namespace cute;
+
+#define OPTIMIZE_WEIGHT_LAYOUT 1
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM kernel configurations
+/////////////////////////////////////////////////////////////////////////////////////////////////
+using MmaType = cutlass::bfloat16_t;
+using QuantType = cutlass::int2b_t;
+constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
+
+// A matrix configuration
+using ElementA = MmaType;                                               // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                              // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = QuantType;                                             // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                           // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// This example manually swaps and transposes, so keep transpose of input layouts
+using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+
+using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
+using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
+
+#if OPTIMIZE_WEIGHT_LAYOUT
+// Define the CuTe layout for the reordered quantized tensor B
+// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
+// It specifies the reordering within a single warp's fragment
+using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
+using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
+#endif
+
+using ElementScale = MmaType;
+using ElementZero = ElementScale; // only used for verification
+using LayoutScale = cutlass::layout::RowMajor;
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                       // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::RowMajor;                              // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// D matrix configuration
+using ElementD = ElementC;
+using LayoutD = LayoutC;
+constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+// Core kernel configurations
+using ElementAccumulator = float;                                       // Element type for internal accumulation
+using ElementCompute = float;                                           // Element type for epilogue computation
+using ArchTag = cutlass::arch::Sm90;                                    // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;                   // Operator class tag
+using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>;               // Threadblock-level tile size
+using ClusterShape = Shape<_1,_1,_1>;                                   // Shape of the threadblocks in a cluster
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    EpilogueTileType,
+    ElementAccumulator, ElementAccumulator,
+    // Transpose layout of D here since we use explicit swap + transpose
+    // the void type for C tells the builder to allocate 0 smem for the C matrix.
+    // We can enable this if beta == 0 by changing ElementC to void below.
+    ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
+    ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
+    EpilogueSchedule // This is the only epi supporting the required swap + transpose.
+  >::CollectiveOp;
+
+// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
+// The scale information must get paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's information and the scale information.
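+// Note on the scale layout consumed by this mainloop: with group size g, the scales for B have
+// shape [N, scale_k, L] with scale_k = ceil_div(K, g), stored MN-major (N fastest-varying).
+// For example, K = 4096 with g = 128 gives scale_k = 32 scale values per column of B.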
+using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+#if OPTIMIZE_WEIGHT_LAYOUT
+    cute::tuple<ElementB, ElementScale>, LayoutB_Reordered, AlignmentB,
+#else
+    cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
+#endif
+    ElementA, LayoutA_Transpose, AlignmentA,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
+    >,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
+    Shape<int,int,int,int>, // Indicates ProblemShape
+    CollectiveMainloopScaleOnly,
+    CollectiveEpilogue
+>;
+
+using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
+
+using StrideC = typename GemmKernelScaleOnly::StrideC;
+using StrideD = typename GemmKernelScaleOnly::StrideD;
+
+using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
+using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
+
+//
+// Data members
+//
+
+/// Initialization
+StrideA stride_A;
+StrideB stride_B;
+StrideC stride_C;
+StrideC_ref stride_C_ref;
+StrideD stride_D;
+StrideD_ref stride_D_ref;
+uint64_t seed;
+
+#if OPTIMIZE_WEIGHT_LAYOUT
+LayoutB_Reordered layout_B_reordered;
+#endif
+
+using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
+using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
+StrideS stride_S;
+StrideS_ref stride_S_ref;
+
+cutlass::DeviceAllocation<ElementA> block_A;
+cutlass::DeviceAllocation<ElementB> block_B;
+cutlass::DeviceAllocation<MmaType> block_B_dq;
+cutlass::DeviceAllocation<ElementScale> block_scale;
+cutlass::DeviceAllocation<ElementZero> block_zero;
+cutlass::DeviceAllocation<ElementC> block_C;
+cutlass::DeviceAllocation<ElementD> block_D;
+cutlass::DeviceAllocation<ElementD> block_ref_D;
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Testbed utility types
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Command line options parsing
+struct Options {
+
+  bool help = false;
+
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  int iterations = 10;
+  int m = 5120, n = 4096, k = 4096;
+  int g = 128;
+  int l = 1;
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("m", m);
+    cmd.get_cmd_line_argument("n", n);
+    cmd.get_cmd_line_argument("k", k);
+    cmd.get_cmd_line_argument("l", l);
+    cmd.get_cmd_line_argument("g", g);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations);
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "55_hopper_int2_bf16_gemm\n\n"
+        << "  Hopper INT2 x BF16 mixed-dtype GEMM using a Warp Specialized kernel.\n\n"
+        << "Options:\n\n"
+        << "  --help                      If specified, displays this usage statement\n\n"
+        << "  --m=<int>                   Sets the M extent of the GEMM\n"
+        << "  --n=<int>                   Sets the N extent of the GEMM\n"
+        << "  --k=<int>                   Sets the K extent of the GEMM\n"
+        << "  --l=<int>                   The number of independent gemm problems with mnk shape\n"
+        << "  --g=<int>                   The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
+        << "  --alpha=<f32>               Epilogue scalar alpha\n"
+        << "  --beta=<f32>                Epilogue scalar beta\n\n"
+        << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";
+
+    out
+      << "\n\nExamples:\n\n"
+      << "$ " << "55_hopper_int2_bf16_gemm" << " --m=1024 --n=512 --k=1024 --g=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
+
+    return out;
+  }
+
+  /// Compute performance in GFLOP/s
+  double gflops(double runtime_s) const
+  {
+    // Two flops per multiply-add
+    uint64_t flop = uint64_t(2) * m * n * k * l;
+    double gflop = double(flop) / double(1.0e9);
+    return gflop / runtime_s;
+  }
+};
+
+/// Result structure
+struct Result
+{
+  double avg_runtime_ms = 0.0;
+  double gflops = 0.0;
+  cutlass::Status status = cutlass::Status::kSuccess;
+  cudaError_t error = cudaSuccess;
+  bool passed = false;
+
+};
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM setup and evaluation
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to initialize a block of device data
+template <class Element>
+bool initialize_tensor(
+  cutlass::DeviceAllocation<Element>& block,
+  uint64_t seed=2023) {
+
+  double scope_max, scope_min;
+  int bits_input = cutlass::sizeof_bits<Element>::value;
+  int bits_output = cutlass::sizeof_bits<Element>::value;
+
+  if (bits_input == 1) {
+    scope_max = 2;
+    scope_min = 0;
+  }
+  else if (bits_input <= 8) {
+    scope_max = 2;
+    scope_min = -2;
+  }
+  else if (bits_output == 16) {
+    scope_max = 5;
+    scope_min = -5;
+  }
+  else {
+    scope_max = 8;
+    scope_min = -8;
+  }
+  cutlass::reference::device::BlockFillRandomUniform(
+    block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
+
+  return true;
+}
+
+template <class Element>
+bool initialize_quant_tensor(
+  cutlass::DeviceAllocation<Element>& block,
+  uint64_t seed=2023) {
+
+  // Print the quantized element type and its width in bits (handy when switching QuantType)
+  std::cout << "Element type: " << typeid(Element).name() << std::endl;
+  std::cout << "Element size in bits: " << cutlass::sizeof_bits<Element>::value << std::endl;
+
+  float scope_min = -1.f;
+  float scope_max = 1.f;
+
+  std::cout << "scope_min: " << scope_min << "\n";
+  std::cout << "scope_max: " << scope_max << "\n";
+
+  cutlass::reference::device::BlockFillRandomUniform(
+    block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
+
+  return true;
+}
+
+template <class Element>
+bool initialize_scale(
+  cutlass::DeviceAllocation<Element>& block,
+  Options const& options) {
+
+  float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
+  float const max_dequant_val = 2.f;
+  float const min_dequant_val = 0.75f;
+
+  float scope_max(max_dequant_val / elt_max_f);
+  float scope_min(min_dequant_val / elt_max_f);
+
+  cutlass::reference::device::BlockFillRandomUniform(
+    block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
+  return true;
+}
+
+template <class Element>
+bool initialize_zero(
+  cutlass::DeviceAllocation<Element>& block,
+  Options const& options) {
+  std::vector<Element> stage(block.size(), Element(0.0f));
+  block.copy_from_host(stage.data());
+  return true;
+}
+
+/// Initialize operands to be used in the GEMM and reference GEMM
+void initialize(Options const& options) {
+
+  auto shape_B = cute::make_shape(options.n, options.k, options.l);
+  int const scale_k = (options.k + options.g - 1) / options.g;
+  stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
+  stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
+  // Reverse stride here due to swap and transpose
+  stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
+  stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
+  // Reverse stride here due to swap and transpose
+  stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
+  stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
+
+  auto layout_B = make_layout(shape_B, stride_B);
+
+  auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
+  auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
+  auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
+
+  block_A.reset(a_coord.product());
+  block_B.reset(b_coord.product());
+  block_B_dq.reset(b_coord.product());
+  block_C.reset(c_coord.product());
+  block_D.reset(c_coord.product());
+  block_ref_D.reset(c_coord.product());
+
+  block_scale.reset(scale_k * options.l * options.n);
+  block_zero.reset(scale_k * options.l * options.n);
+
+  initialize_tensor(block_A, seed + 2022);
+  initialize_quant_tensor(block_B, seed + 2021);
+  initialize_tensor(block_C, seed + 2020);
+  initialize_scale(block_scale, options);
+  initialize_zero(block_zero, options);
+
+  auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
+  stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
+  stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
+  auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
+
+  dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
+
+#if OPTIMIZE_WEIGHT_LAYOUT
+  // Repeat the reorder layout atom to tile the whole tensor shape
+  layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
+  reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
+
+  print("Quantized tensor layout: ");
+  print(layout_B_reordered);
+  print("\n");
+#endif
+}
+
+/// Populates a Gemm::Arguments structure from the given commandline options
+template <typename Args>
+Args args_from_options(Options const& options)
+{
+  // Swap the A and B tensors, as well as the problem shape, here.
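+  // Because of this swap, the problem shape below is passed as (N, M, K, L) and the epilogue uses
+  // the reversed (n, m, l) strides built in initialize(); the kernel effectively computes
+  // D^T = B^T * A^T, which lands in the same memory as the untransposed row-major D.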
+
+  return Args {
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    {options.n, options.m, options.k, options.l},
+#if OPTIMIZE_WEIGHT_LAYOUT
+    {block_B.get(), layout_B_reordered, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
+#else
+    {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
+#endif
+    {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
+  };
+}
+
+bool verify(Options const& options) {
+  //
+  // Compute reference output
+  //
+
+  using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      MmaType, LayoutA, AlignmentA,
+      MmaType, LayoutB, AlignmentB,
+      ElementAccumulator,
+      TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAuto,
+      cutlass::gemm::collective::KernelScheduleAuto
+    >::CollectiveOp;
+
+  using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutC, AlignmentC,
+      ElementD, LayoutD, AlignmentD,
+      cutlass::epilogue::NoSmemWarpSpecialized
+    >::CollectiveOp;
+
+  using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>, // Indicates ProblemShape
+      CollectiveMainloopRef,
+      CollectiveEpilogueRef
+  >;
+
+  using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
+
+  typename GemmRef::Arguments arguments{
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    {options.m, options.n, options.k, options.l},
+    {block_A.get(), stride_A, block_B_dq.get(), stride_B},
+    {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
+  };
+
+  // Run the gemm where the scaling is performed outside of the kernel.
+  GemmRef gemm_ref;
+  size_t workspace_size = GemmRef::get_workspace_size(arguments);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  CUTLASS_CHECK(gemm_ref.can_implement(arguments));
+  CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
+  CUTLASS_CHECK(gemm_ref.run());
+
+  // compare_reference
+  ElementD const epsilon(1.0);
+  ElementD const non_zero_floor(0.2);
+  bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
+
+  return passed;
+}
+
+/// Execute a given example GEMM computation
+template <typename Gemm>
+int run(Options &options)
+{
+  initialize(options);
+
+  // Instantiate CUTLASS kernel depending on templates
+  Gemm gemm;
+
+  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
+  auto arguments = args_from_options<typename Gemm::Arguments>(options);
+
+  // Using the arguments, query for extra workspace required for matrix multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check if the problem size is supported or not
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+
+  // Correctness / Warmup iteration
+  CUTLASS_CHECK(gemm.run());
+
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  Result result;
+  result.passed = verify(options);
+
+  std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
"Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // {$nv-internal-release begin} + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + // {$nv-internal-release end} + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." 
+  run<GemmScaleOnly>(options);
+#endif
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt
index 23dca4f3fd..95a8991ef6 100644
--- a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt
+++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt
@@ -43,6 +43,7 @@ set(TEST_SCALE_ZERO_RESIDUE --m=128 --n=128 --k=192 --g=128 --mode=2 --iteration
 set(TEST_ALPHA_BETA --alpha=0.5 --beta=0.7 --mode=2 --iterations=0) # Alpha and Beta with default shapes
 
+#set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -G")
 cutlass_example_add_executable(
   55_hopper_mixed_dtype_gemm
@@ -79,3 +80,14 @@ cutlass_example_add_executable(
   TEST_SCALE_RESIDUE
   # TEST_ALPHA_BETA
 )
+
+cutlass_example_add_executable(
+  55_hopper_int2_bf16_gemm
+  55_hopper_int2_bf16_gemm.cu
+  TEST_COMMAND_OPTIONS
+  TEST_DIRECT_BATCHED
+  TEST_SCALE_PERCOL
+  TEST_SCALE_GROUP
+  TEST_SCALE_RESIDUE
+  # TEST_ALPHA_BETA
+)
diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h
index 17c1ac14d6..42fac4a518 100644
--- a/include/cutlass/numeric_conversion.h
+++ b/include/cutlass/numeric_conversion.h
@@ -4108,6 +4108,144 @@ struct NumericArrayConverter {
 
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::int2b_t, N>
+template <FloatRoundStyle Round, int N>
+struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::int2b_t, N, Round> {
+  using result_type = Array<cutlass::bfloat16_t, N>;
+  using source_type = Array<cutlass::int2b_t, N>;
+
+  static FloatRoundStyle const round_style = Round;
+
+private:
+  using result_type_packed_16 = Array<cutlass::bfloat16_t, 16>;
+  using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
+  using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
+  using source_type_packed_16 = Array<cutlass::int2b_t, 16>;
+  using source_type_packed_8 = Array<cutlass::int2b_t, 8>;
+  using source_type_packed_4 = Array<cutlass::int2b_t, 4>;
+
+  using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::int2b_t, Round>;
+
+  CUTLASS_DEVICE
+  static uint32_t to_reg(source_type_packed_4 const& source) {
+    return static_cast<uint32_t>(
+      reinterpret_cast<const uint8_t&>(source));
+  }
+
+  CUTLASS_DEVICE
+  static uint32_t to_reg(source_type_packed_8 const& source) {
+    return static_cast<uint32_t>(
+      reinterpret_cast<const uint16_t&>(source));
+  }
+
+  CUTLASS_DEVICE
+  static uint32_t to_reg(source_type_packed_16 const& source) {
+    return reinterpret_cast<const uint32_t&>(source);
+  }
+
+  template <typename PackedSrcType, typename PackedResultType>
+  CUTLASS_DEVICE
+  static PackedResultType packed_convert(PackedSrcType const &source) {
+
+    static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
+                   platform::is_same<PackedResultType, result_type_packed_4>::value) ||
+                  (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
+                   platform::is_same<PackedResultType, result_type_packed_8>::value) ||
+                  (platform::is_same<PackedSrcType, source_type_packed_16>::value &&
+                   platform::is_same<PackedResultType, result_type_packed_16>::value),
+                  "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");
+
+    using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
+    RegArray r;
+
+    // View the input as reg
+    uint32_t src_reg = to_reg(source);
+    uint32_t src_reg_shifted_two = src_reg >> 2;
+    uint32_t src_reg_shifted_four = src_reg >> 4;
+    uint32_t src_reg_shifted_six = src_reg >> 6;
+
+    // Modified prmt indices for signed 2-bit values
+    uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
+
+    static_assert(RegArray::kElements <= 8, "Too many inputs for BF16 -> SI2 vector converter");
+
+    // First pass: extract and sign extend the 2-bit values
+    CUTLASS_PRAGMA_UNROLL
+    for (int ii = 0; ii < RegArray::kElements; ii += 2) {
+      asm volatile(
+          "{\n"
+          "  prmt.b32 %0, %1, %2, %3;\n"
+          "}\n"
+          : "=r"(r[ii])
+          : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2]));
+
+      asm volatile(
"{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii + 1]) + : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2])); + } + + // For signed 2-bit integers: + // 00 -> 0 (0) + // 01 -> 1 (1) + // 10 -> -2 (2 with sign extension) + // 11 -> -1 (3 with sign extension) + //static constexpr uint32_t sign_mask = 0x00020002; // Mask to check sign bit + static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits + + // Modified for signed range (-2 to 1) + // We'll construct numbers in the form 128 + (x + 2) and then subtract 130 + // to get back to our original range + static constexpr uint32_t xor_mask = 0x43024302; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // Bias represents 130 in bfloat16 format + // Subtracting 130 brings us back to our signed range (-2 to 1) + static constexpr uint32_t bias_rep = 0x43024302; // {130, 130} in bfloat16 + const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + /// Partial specialization for Array <= Array template struct NumericArrayConverter {