Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QST] bfloat16 x int8 GEMM #1936

Open
sycz00 opened this issue Nov 11, 2024 · 0 comments
Open

[QST] bfloat16 x int8 GEMM #1936

sycz00 opened this issue Nov 11, 2024 · 0 comments

Comments

@sycz00
Copy link

sycz00 commented Nov 11, 2024

What is your question?
Hey folks,

I was inspired by the following PR. Having tried to follow the examples, I am still not sure how to achieve my goal.
My goal is to perform a GEMM, as used in a linear layer, in which the weights are stored as 8-bit integer (int8) values and the activations are integer values that need more than 8 bits but are stored in BF16 or F16 format.

Starting with the following CUDA extension built through the CUTLASS Python interface, no valid kernel combination can be found.

# Describe the desired GEMM: A is int8, B is bfloat16, C is int8, and both D
# and the accumulator are int32; a single row-major layout is requested.
# NOTE(review): the surrounding issue text reports that no valid combination
# is found for this configuration — presumably because A and B use different
# element types; confirm against the CUTLASS Python interface documentation.
plan = cutlass.op.Gemm(
    element_A=torch.int8,
    element_B=torch.bfloat16,
    element_C=torch.int8,
    element_D=torch.int32,
    element_accumulator=torch.int32,
    #element_D=cutlass.DataType.s32,
    layout=cutlass.LayoutType.RowMajor)
# Build the operation object from the plan.
op = plan.construct()

# Emit a PyTorch extension module named 'gemm' for the plan's `cc` value,
# with jit=True.
mod = cutlass.emit.pytorch(op, name='gemm', cc=plan.cc, jit=True)

If this is not possible via the Python interface, how could it be done using the following CUDA file? The reason I am having trouble is that I don't fully understand the example in folder 55, as it is structured differently from the one below:

// This file was automatically generated by the CUTLASS 3.6.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
    if (size > 0) {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size,}, options);
        return reinterpret_cast<void*>(device_tensor.data_ptr());
    } else {
        return nullptr;
    }
}


#include "cutlass/gemm/device/gemm_universal.h"


// Gemm operator cutlass_simt_s8_igemm_s8_128x128_8x2_tt_align1
//
// Device-level GEMM type used by the wrapper functions in this file.
// Operands A, B and C are all int8_t with row-major layout and the
// accumulator is int32_t, so this instantiation does not take a bfloat16
// operand.
// NOTE(review): the epilogue functor below is instantiated with int32_t as
// its first type argument while operand C is declared as int8_t — confirm
// that the epilogue's output element type is meant to differ from C.
using DeviceKernel =
    typename cutlass::gemm::device::GemmUniversal<
        // Data type and layout of operand A
        int8_t, cutlass::layout::RowMajor,
        // Data type and layout of operand B
        int8_t, cutlass::layout::RowMajor,
        // Data type and layout of operand C
        int8_t, cutlass::layout::RowMajor,
        // Data type of accumulator
        int32_t,
        // Class of operation
        cutlass::arch::OpClassSimt,
        // Compute capability of the target kernel
        cutlass::arch::Sm80,
        // Threadblock tile shape
        cutlass::gemm::GemmShape<128, 128, 8>,
        // Warp tile shape
        cutlass::gemm::GemmShape<32, 64, 8>,
        // Instruction shape
        cutlass::gemm::GemmShape<1, 1, 1>,
        // Epilogue functor
        cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
        // Swizzling function
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        // Number of pipeline stages
        2,
        // Alignment of operands A and B
        1, 1,
        // Type of math operation
        cutlass::arch::OpMultiplyAdd,
        // Complex transform types of operands A and B
        cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone
    >;


// Scalar type of the alpha/beta arguments, taken from DeviceKernel's epilogue functor.
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

/// Configures DeviceKernel for a single (non-batched) M x N x K problem and
/// runs it.
///
/// @param M, N, K      Problem dimensions.
/// @param A, B, C      Input operand pointers, forwarded to the kernel as-is.
/// @param D            Output pointer, forwarded to the kernel as-is.
/// @param alpha, beta  Scalars handed to the epilogue.
/// @return The status from initialize() if it is not kSuccess, otherwise the
///         status returned by invoking the kernel object.
cutlass::Status gemm_mod_kernel_run(int M, int N, int K,
                        const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
                        ElementCompute alpha, ElementCompute beta) {

  // Leading dimensions are derived from packed layouts of each operand's extent.
  const auto lda = DeviceKernel::LayoutA::packed({M, K}).stride(0);
  const auto ldb = DeviceKernel::LayoutB::packed({K, N}).stride(0);
  const auto ldc = DeviceKernel::LayoutC::packed({M, N}).stride(0);
  const auto ldd = DeviceKernel::LayoutC::packed({M, N}).stride(0);

  typename DeviceKernel::Arguments args {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},          // problem size
      1,
      {alpha, beta},
      A, B, C, D,
      0, 0, 0, 0,         // batch strides
      lda, ldb, ldc, ldd
  };

  // Workspace buffer of whatever size the kernel reports for these arguments.
  cutlass::device_memory::allocation<uint8_t> scratch(DeviceKernel::get_workspace_size(args));

  DeviceKernel gemm;
  const cutlass::Status init_status = gemm.initialize(args,
                                                      scratch.get(),
                                                      nullptr);     // CUDA stream
  if (init_status != cutlass::Status::kSuccess) {
    return init_status;
  }

  return gemm();
}

/// PyTorch-facing entry point: runs DeviceKernel on tensors A and B (and an
/// optional C) and returns the result as a new int8 tensor of shape {M, N}.
///
/// @param A      2-D tensor; M = A.size(0), K = A.size(1).
/// @param B      2-D tensor; N = B.size(1); B.size(0) must equal K.
/// @param C      Optional source tensor; a null pointer is passed to the
///               kernel when it is absent.
/// @param alpha  Scalar converted to ElementCompute and passed to the kernel.
/// @param beta   Scalar converted to ElementCompute and passed to the kernel.
/// @return Tensor D allocated with B.new_empty({M, N}, torch::kI8).
/// @throws c10::Error (via TORCH_CHECK) on a shape mismatch or when the
///         kernel does not report kSuccess.
at::Tensor gemm_mod_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    TORCH_CHECK(A.dim() == 2 && B.dim() == 2, "gemm_mod_kernel expects 2-D tensors for A and B");
    TORCH_CHECK(A.size(1) == B.size(0), "gemm_mod_kernel: inner dimensions of A and B do not match");

    const int M = static_cast<int>(A.size(0));
    const int N = static_cast<int>(B.size(1));
    const int K = static_cast<int>(A.size(1));

    // Hold the contiguous tensors in named locals so the storage behind every
    // raw pointer stays referenced until the kernel call has returned. Taking
    // data_ptr() from a temporary contiguous() result leaves the pointer
    // dangling whenever contiguous() had to make a copy.
    const at::Tensor A_contig = A.contiguous();
    const at::Tensor B_contig = B.contiguous();
    at::Tensor C_contig;
    typename DeviceKernel::ElementC* ptrC = nullptr;
    if (C.has_value()) {
        C_contig = C->contiguous();
        ptrC = static_cast<typename DeviceKernel::ElementC*>(C_contig.data_ptr());
    }

    at::Tensor D = B.new_empty({M, N}, torch::kI8);

    const cutlass::Status status = gemm_mod_kernel_run(M, N, K,
                                                static_cast<typename DeviceKernel::ElementA*>(A_contig.data_ptr()),
                                                static_cast<typename DeviceKernel::ElementB*>(B_contig.data_ptr()),
                                                ptrC,
                                                static_cast<typename DeviceKernel::ElementC*>(D.data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

1 participant