You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
int const m = 16;
int const n = 16;
int const k = 8;
int const batch_count = 65536;
Anything below, like batch_count=65536-1 works.
Steps/Code to reproduce bug
I managed to reduce the failing code to this:
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_batched.h"
#include "cutlass/layout/matrix.h"
/// Runs a strided batched SGEMM via CUTLASS with all operands column-major.
///
/// For each batch index b in [0, batch_count), computes
///   C_b = alpha * A_b * B_b + beta * C_b
/// where X_b = X + b * batch_stride_X (strides are in elements).
///
/// Returns cudaSuccess on success; prints the CUTLASS status string to
/// stderr and returns cudaErrorUnknown if the GEMM launch fails.
cudaError_t cutlass_strided_batched_sgemm(int m, int n, int k, float alpha, float const* A, int lda,
                                          long long int batch_stride_A, float const* B, int ldb,
                                          long long int batch_stride_B, float* C, int ldc,
                                          long long int batch_stride_C, float beta,
                                          int batch_count) {
  // Single-precision batched GEMM; A, B, C are all column-major.
  using Gemm = cutlass::gemm::device::GemmBatched<float, cutlass::layout::ColumnMajor, float,
                                                  cutlass::layout::ColumnMajor, float,
                                                  cutlass::layout::ColumnMajor>;
  Gemm gemm_op;

  // C serves both as the accumulator source (scaled by beta) and as the
  // destination D, so its ref/stride appear twice in the argument list.
  cutlass::Status status = gemm_op({{m, n, k},
                                    {A, lda},
                                    batch_stride_A,
                                    {B, ldb},
                                    batch_stride_B,
                                    {C, ldc},
                                    batch_stride_C,
                                    {C, ldc},
                                    batch_stride_C,
                                    {alpha, beta},
                                    batch_count});

  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Cutlass failed with error string " << cutlass::cutlassGetStatusString(status)
              << std::endl;
    return cudaErrorUnknown;
  }
  return cudaSuccess;
}
/// Allocates device buffers and runs the batched SGEMM reproduction case.
///
/// NOTE(review): batch_count = 65536 is the failing size from the bug
/// report — GemmBatched maps the batch index onto gridDim.z, which is
/// limited to 65535, so this call is expected to fail; 65535 works.
cudaError_t run_batched_gemm() {
  // Arbitrary problem size.
  int const m = 16;
  int const n = 16;
  int const k = 8;
  int const batch_count = 65536;

  // A, B are non-transpose, column major; B batches are packed along K.
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  // Total element counts of each allocation.
  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // The memory is batched along the K dimension (strides in elements).
  long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int batch_stride_B = static_cast<long long int>(k);
  long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

  // Epilogue scalars: C = alpha * A * B + beta * C.
  float alpha = 1.0f;
  float beta = 2.0f;

  // Allocate the device memory, checking EACH allocation. The original
  // code overwrote `result` on every call and only inspected the last one,
  // so a failure of the A or B allocation went unnoticed.
  float* A = nullptr;
  float* B = nullptr;
  float* C = nullptr;
  cudaError_t result = cudaMalloc(&A, count_A * sizeof(float));
  if (result == cudaSuccess) {
    result = cudaMalloc(&B, count_B * sizeof(float));
  }
  if (result == cudaSuccess) {
    result = cudaMalloc(&C, count_C * sizeof(float));
  }

  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
  } else {
    result =
        cutlass_strided_batched_sgemm(m, n, k, alpha, A, lda, batch_stride_A, B, ldb,
                                      batch_stride_B, C, ldc, batch_stride_C, beta, batch_count);
  }

  // Release device memory on every path (the original leaked all three
  // buffers). cudaFree(nullptr) is a documented no-op, so partially
  // failed allocation sequences are safe here.
  cudaFree(C);
  cudaFree(B);
  cudaFree(A);
  return result;
}
/// Entry point: runs the reproduction case and reports pass/fail.
/// Exits with 0 on success, -1 on any CUDA/CUTLASS error.
int main() {
  cudaError_t result = run_batched_gemm();
  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  } else {
    std::cout << "There was an error." << std::endl;
  }
  return result == cudaSuccess ? 0 : -1;
}
Expected behavior
I would not expect this API call to have a maximum number of blocks so low. Or at least have a more informative error status if that is the actual limit.
Environment details (please complete the following information):
Environment location: Bare-metal, a CUDA 12.3 installation with the current master, a8f2c80
Additional context
The error also arises in the v3.3.0 git tag.
The text was updated successfully, but these errors were encountered:
The batch index is assigned to blockIdx.z, which is limited to 16 bits (a maximum of 65535). You could split the batch across two kernel launches; with a batch count this large, that should not impact performance.
Thanks for the quick answer!
I suspected it had to do with block sizes. I was surprised about the error type.
Did I miss something in the documentation about it? Sorry, I am still learning to navigate it.
Thanks for the good work, you are awesome.
Describe the bug
I am trying to run the example https://github.com/NVIDIA/cutlass/blob/main/examples/05_batched_gemm/batched_gemm.cu
Which I am compiling to run on an RTX 4090 with:
nvcc -arch=sm_89 batched_gemm.cu -I${CUTLASS_ROOT}/include -run
The code runs well until I change the problem size:
cutlass/examples/05_batched_gemm/batched_gemm.cu
Lines 246 to 250 in a8f2c80
to:
Anything below, like batch_count = 65536 - 1, works.
Steps/Code to reproduce bug
I managed to reduce the failing code to this:
Expected behavior
I would not expect this API call to have a maximum number of blocks so low. Or at least have a more informative error status if that is the actual limit.
Environment details (please complete the following information):
Additional context
The error also arises in the v3.3.0 git tag.
The text was updated successfully, but these errors were encountered: