What is your question?
Hey folks,
I was inspired by the following PR. Trying to follow the examples there, I am not sure how to achieve what I want.
My goal is to perform a GEMM as used in a linear layer, where the weights are stored as 8-bit integers and the activations are integer values of a bit-width higher than 8, but stored in BF16 or F16 format.
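To make the intended computation concrete, this is the reference I check against in PyTorch. It is only a sketch: the fp32 accumulation and the nn.Linear-style (N, K) weight layout are my assumptions, not requirements on the kernel.

import torch
from typing import Optional

def reference_mixed_linear(a_f16: torch.Tensor, w_int8: torch.Tensor,
                           c: Optional[torch.Tensor] = None,
                           alpha: float = 1.0, beta: float = 0.0) -> torch.Tensor:
    # a_f16: activations in fp16/bf16 (holding integer values), shape (M, K)
    # w_int8: weights stored as int8, shape (N, K) as in torch.nn.Linear
    # A fused mixed-input GEMM would perform the int8 -> fp16/bf16 conversion
    # inside the mainloop instead of materializing the weights in the wide type.
    acc = alpha * (a_f16.float() @ w_int8.float().t())
    if c is not None:
        acc = acc + beta * c.float()
    return acc.to(a_f16.dtype)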
Starting from the Python interface for generating CUDA extensions, I cannot find any valid combination of data types; a sketch of the kind of call I am attempting is shown below.
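Something along these lines (the dtype and layout arguments are only illustrative, and the argument names follow the cutlass.op.Gemm constructor; I have tried f16/bf16 for A together with s8 for B):

import cutlass

# Illustrative attempt only: fp16 activations as operand A, int8 weights as
# operand B. No kernel is found for any such mixed A/B element combination.
plan = cutlass.op.Gemm(
    element_A=cutlass.DataType.f16,
    element_B=cutlass.DataType.s8,
    element_C=cutlass.DataType.f16,
    element_D=cutlass.DataType.f16,
    element_accumulator=cutlass.DataType.f32,
    layout_A=cutlass.LayoutType.RowMajor,
    layout_B=cutlass.LayoutType.RowMajor,
    layout_C=cutlass.LayoutType.RowMajor,
)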
If this is not possible via the Python interface, how could it be done using the following CUDA file? The reason I am having trouble is that I don't fully understand the example in folder 55, since it is quite different from the one below:
// This file was automatically generated by the CUTLASS 3.6.0 Python interface (https://github.com/nvidia/cutlass/python)
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id = 0) {
    if (size > 0) {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size,}, options);
        return reinterpret_cast<void*>(device_tensor.data_ptr());
    } else {
        return nullptr;
    }
}
#include"cutlass/gemm/device/gemm_universal.h"// Gemm operator cutlass_simt_s8_igemm_s8_128x128_8x2_tt_align1using DeviceKernel =
typename cutlass::gemm::device::GemmUniversal<
// Data type and layout of operand Aint8_t, cutlass::layout::RowMajor,
// Data type and layout of operand Bint8_t, cutlass::layout::RowMajor,
// Data type and layout of operand Cint8_t, cutlass::layout::RowMajor,
// Data type of accumulatorint32_t,
// Class of operation
cutlass::arch::OpClassSimt,
// Compute capability of the target kernel
cutlass::arch::Sm80,
// Threadblock tile shape
cutlass::gemm::GemmShape<128, 128, 8>,
// Warp tile shape
cutlass::gemm::GemmShape<32, 64, 8>,
// Instruction shape
cutlass::gemm::GemmShape<1, 1, 1>,
// Epilogue functor
cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
// Swizzling function
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
// Number of pipeline stages2,
// Alignment of operands A and B1, 1,
// Type of math operation
cutlass::arch::OpMultiplyAdd,
// Complex transform types of operands A and B
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone
>;
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

cutlass::Status gemm_mod_kernel_run(int M, int N, int K,
                                    const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B,
                                    const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
                                    ElementCompute alpha, ElementCompute beta) {
    typename DeviceKernel::Arguments arguments {
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K},                                       // problem size
        1,                                               // batch count
        {alpha, beta},
        A, B, C, D,
        0, 0, 0, 0,                                      // batch strides
        DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
        DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
        DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
        DeviceKernel::LayoutC::packed({M, N}).stride(0)  // ldd
    };

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments,
                                                workspace.get(),
                                                nullptr); // CUDA stream
    if (status != cutlass::Status::kSuccess) {
        return status;
    }

    status = gemm_op();
    return status;
}
at::Tensor gemm_mod_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
        nullptr :
        reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());

    at::Tensor D = B.new_empty({M, N}, torch::kI8);

    cutlass::Status status = gemm_mod_kernel_run(M, N, K,
        reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
        reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
        ptrC,
        reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
        ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
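For completeness, this is roughly how I call the generated extension today for the plain int8 x int8 case. The module name and the pybind11 binding of gemm_mod_kernel come from my own build setup and are not shown in the file above, so treat them as placeholders:

import torch
import gemm_mod  # placeholder name for the compiled extension exposing gemm_mod_kernel

M, N, K = 256, 512, 128
A = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
B = torch.randint(-8, 8, (K, N), dtype=torch.int8, device="cuda")

# C is optional; alpha/beta feed the int32 LinearCombination epilogue above,
# so fractional values are truncated when cast to ElementCompute.
D = gemm_mod.gemm_mod_kernel(A, B, None, 1.0, 0.0)
print(D.shape, D.dtype)  # torch.Size([256, 512]) torch.int8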