Skip to content

Commit

Permalink
Use ark abbreviation
Browse files: browse the repository at this point in the history
  • Loading branch information
chhwang committed Mar 22, 2024
1 parent c856e4e commit 20840bb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
13 changes: 7 additions & 6 deletions ark/gpu/gpu_loop_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ GpuLoopKernel::GpuLoopKernel(std::shared_ptr<GpuContext> ctx,
"#define ARK_THREADS_PER_BLOCK " << block_dim_[0] << "\n"
"__device__ int _ITER = 0;\n"
"#include \"ark_kernels.h\"\n"
"__device__ ark::sync::State " ARK_LSS_NAME ";\n"
"using namespace ark;\n"
"__device__ sync::State " ARK_LSS_NAME ";\n"
"__device__ char *" ARK_BUF_NAME ";\n"
<< *ark_loop_body_code <<
"extern \"C\" __global__ __launch_bounds__(" << block_dim_[0] << ", 1)\n"
Expand All @@ -81,21 +82,21 @@ GpuLoopKernel::GpuLoopKernel(std::shared_ptr<GpuContext> ctx,
" for (;;) {\n"
" if (threadIdx.x == 0 && blockIdx.x == 0) {\n"
" int iter;\n"
" while ((iter = ark::atomicLoadRelaxed(_it)) == 0) {}\n"
" while ((iter = atomicLoadRelaxed(_it)) == 0) {}\n"
" _ITER = iter;\n"
" }\n"
" ark::sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" if (_ITER < 0) {\n"
" return;\n"
" }\n"
" for (int _i = 0; _i < _ITER; ++_i) {\n"
" ark_loop_body(_buf, _i);\n"
" ark::sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" }\n"
" if (threadIdx.x == 0 && blockIdx.x == 0) {\n"
" ark::atomicStoreRelaxed(_it, 0);\n"
" atomicStoreRelaxed(_it, 0);\n"
" }\n"
" ark::sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" sync_gpu<" << num_sm << ">(" ARK_LSS_NAME ");\n"
" }\n"
"}\n";
// clang-format on
Expand Down
14 changes: 7 additions & 7 deletions ark/include/ark/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ std::ostream &operator<<(std::ostream &os, const TensorType &type);
}; \
const TensorType_##_type_name _type_name;

// Tensor element types declared through REGISTER_TENSOR_TYPE(name, number,
// string).
// NOTE(review): the numeric argument looks like the element size in bytes and
// the string like the type's spelling in generated kernel code; the macro body
// is not visible in this hunk, so confirm against its definition.
//
// This span is a diff hunk whose +/- markers were lost. The hunk header
// (-51,13 +51,13) together with the "7 additions & 7 deletions" summary
// implies that the first seven registrations are the lines REMOVED by the
// commit (pre-change strings):
REGISTER_TENSOR_TYPE(FP32, 4, "float")
REGISTER_TENSOR_TYPE(FP16, 2, "ark::fp16")
REGISTER_TENSOR_TYPE(BF16, 2, "ark::bf16")
REGISTER_TENSOR_TYPE(INT32, 4, "int32_t")
REGISTER_TENSOR_TYPE(UINT32, 4, "uint32_t")
REGISTER_TENSOR_TYPE(INT8, 1, "int8_t")
REGISTER_TENSOR_TYPE(UINT8, 1, "uint8_t")
// ...and the next seven are the lines ADDED by the commit. Only the string
// argument changes: each becomes a short, unqualified name.
// NOTE(review): presumably these short names are resolved inside the generated
// kernel source -- verify that matching aliases exist there.
REGISTER_TENSOR_TYPE(FP32, 4, "fp32")
REGISTER_TENSOR_TYPE(FP16, 2, "fp16")
REGISTER_TENSOR_TYPE(BF16, 2, "bf16")
REGISTER_TENSOR_TYPE(INT32, 4, "i32")
REGISTER_TENSOR_TYPE(UINT32, 4, "ui32")
REGISTER_TENSOR_TYPE(INT8, 1, "i8")
REGISTER_TENSOR_TYPE(UINT8, 1, "ui8")
// Unchanged context line: BYTE keeps its original string.
REGISTER_TENSOR_TYPE(BYTE, 1, "unsigned char")

class GpuBuffer;
Expand Down

0 comments on commit 20840bb

Please sign in to comment.