diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..993a0d8dd --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,33 @@ +{ + "configurations": [ + { + "name": "ops_matmul_test", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/build/ark/ops_matmul_test.cu", + "args": [], + "stopAtEntry": false, + "cwd": "${fileDirname}", + "environment": [ + { + "name": "ARK_ROOT", + "value": "${workspaceFolder}/build" + } + ], + "externalConsole": false, + "MIMode": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for gdb", + "text": "-enable-pretty-printing", + "ignoreFailures": true + }, + { + "description": "Set Disassembly Flavor to Intel", + "text": "-gdb-set disassembly-flavor intel", + "ignoreFailures": true + } + ] + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json index c19b9e274..855fdf595 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,4 +3,7 @@ "cmake.environment": { "ARK_ROOT": "${workspaceFolder}/build" }, + "cmake.ctestArgs": [ + "--verbose" + ], } diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c916a56b..337d3b6af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,9 +11,11 @@ set(ARK_SOVERSION "${ARK_MAJOR}.${ARK_MINOR}") option(USE_KAHYPAR "Use KaHyPar for scheduling" OFF) cmake_minimum_required(VERSION 3.25) -project(ark LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 14) +project(ark LANGUAGES CXX CUDA) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra") set(BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}) # Find ibverbs @@ -24,6 +26,19 @@ include(${PROJECT_SOURCE_DIR}/cmake/FindNUMA.cmake) # Find CUDAToolkit find_package(CUDAToolkit REQUIRED) +if(CUDAToolkit_FOUND) + if(CUDAToolkit_VERSION_MAJOR LESS 11) + message(FATAL_ERROR "CUDA 11 or higher is required but detected ${CUDAToolkit_VERSION}") + endif() + + if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 11) + set(CMAKE_CUDA_ARCHITECTURES 70 80) + endif() + + if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12) + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES} 70 80 90) + endif() +endif() # Third party libraries add_subdirectory(third_party) diff --git a/ark/CMakeLists.txt b/ark/CMakeLists.txt index c7b0f3721..2f2275f6c 100644 --- a/ark/CMakeLists.txt +++ b/ark/CMakeLists.txt @@ -2,7 +2,7 @@ # Licensed under the MIT license. file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc) -file(GLOB_RECURSE UT_SOURCES CONFIGURE_DEPENDS *_test.cc) +file(GLOB_RECURSE UT_SOURCES CONFIGURE_DEPENDS *_test.cc *_test.cu) file(GLOB_RECURSE UT_COMMON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/unittest/*.cc) list(REMOVE_ITEM SOURCES ${UT_SOURCES} ${UT_COMMON_SOURCES}) file(GLOB_RECURSE INTERFACE_HEADERS CONFIGURE_DEPENDS include/ark*.h) @@ -61,7 +61,7 @@ foreach(ut_source IN ITEMS ${UT_SOURCES}) add_executable(${exe_name} ${ut_source} ${UT_COMMON_SOURCES}) add_dependencies(${exe_name} build) set_target_properties(${exe_name} PROPERTIES EXCLUDE_FROM_ALL TRUE) - target_link_libraries(${exe_name} PRIVATE ark_obj ${COMMON_LIBS}) + target_link_libraries(${exe_name} PRIVATE ark_obj ${COMMON_LIBS} CUDA::cudart CUDA::cublas) target_include_directories(${exe_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(${exe_name} SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/json diff --git a/ark/gpu/gpu_compile.cc b/ark/gpu/gpu_compile.cc index 85180805a..e8c0011bd 100644 --- a/ark/gpu/gpu_compile.cc +++ b/ark/gpu/gpu_compile.cc @@ -16,6 +16,7 @@ #include "gpu/gpu_compile.h" #include "gpu/gpu_logging.h" #include "include/ark.h" +#include "random.h" #include "threading.h" #define ARK_USE_NVRTC 0 @@ -27,21 +28,6 @@ using namespace std; -// Generate a random alpha-numeric string. -static const string rand_anum(size_t len) -{ - auto randchar = []() -> char { - const char charset[] = "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = sizeof(charset) - 1; - return charset[rand() % max_index]; - }; - string str(len, 0); - generate_n(str.begin(), len, randchar); - return str; -} - namespace ark { #if (ARK_USE_NVRTC) diff --git a/ark/gpu/gpu_kernel.cc b/ark/gpu/gpu_kernel.cc index a20d7c172..5b8caa540 100644 --- a/ark/gpu/gpu_kernel.cc +++ b/ark/gpu/gpu_kernel.cc @@ -86,32 +86,28 @@ void GpuKernel::compile(const GpuInfo &gpu_info, bool use_comm_sw) this->cubin = gpu_compile(this->codes, gpu_info.arch, max_reg_cnt, use_comm_sw); } -} -// -void GpuKernel::load() -{ // - unsigned int buflen = 8192; - char *infobuf = new char[buflen]; - char *errbuf = new char[buflen]; - assert(infobuf != nullptr); - assert(errbuf != nullptr); + size_t num_opts = 5; + size_t buflen = 8192; + std::unique_ptr opts(new CUjit_option[num_opts]); + std::unique_ptr optvals(new void *[num_opts]); + std::string infobuf; + std::string errbuf; + + infobuf.resize(buflen, ' '); + errbuf.resize(buflen, ' '); + int enable = 1; - int num_opts = 5; - CUjit_option *opts = new CUjit_option[num_opts]; - void **optvals = new void *[num_opts]; - assert(opts != nullptr); - assert(optvals != nullptr); opts[0] = CU_JIT_INFO_LOG_BUFFER; - optvals[0] = (void *)infobuf; + optvals[0] = (void *)infobuf.data(); opts[1] = CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES; optvals[1] = (void *)(long)buflen; opts[2] = CU_JIT_ERROR_LOG_BUFFER; - optvals[2] = (void *)errbuf; + optvals[2] = (void *)errbuf.data(); opts[3] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES; optvals[3] = (void *)(long)buflen; @@ -119,13 +115,11 @@ void GpuKernel::load() opts[4] = CU_JIT_GENERATE_DEBUG_INFO; optvals[4] = (void *)(long)enable; - if (cuModuleLoadDataEx(&this->module, this->cubin.c_str(), num_opts, opts, - optvals) != CUDA_SUCCESS) { + if (cuModuleLoadDataEx(&this->module, this->cubin.c_str(), num_opts, + opts.get(), optvals.get()) != CUDA_SUCCESS) { LOG(DEBUG, infobuf); LOG(ERROR, "cuModuleLoadDataEx() failed: ", errbuf); } - delete[] infobuf; - delete[] errbuf; CULOG(cuModuleGetFunction(&this->kernel, this->module, this->name.c_str())); // int static_smem_size_bytes; @@ -136,8 +130,6 @@ void GpuKernel::load() CULOG(cuFuncSetAttribute(this->kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, dynamic_smem_size_bytes)); - // Now code string is not needed. - // this->code.clear(); } // @@ -265,7 +257,6 @@ void GpuLoopKernel::compile(const GpuInfo &gpu_info) void GpuLoopKernel::load() { this->ctx->set_current(); - GpuKernel::load(); // if (!this->is_compiled()) { LOG(ERROR, "Need to compile first before initialization."); @@ -368,25 +359,10 @@ GpuState GpuLoopKernel::launch(CUstream stream, bool disable_timing) void GpuLoopKernel::run(int iter) { if (iter > 0) { -#if 0 - int idx = this->flip_flag ? 0 : 1; - int rem = iter; - while (rem--) { - while (this->get_flag(idx) > 0) { - cpu_ntimer_sleep(500); - } - this->set_flag(idx, 1); - idx ^= 1; - } - if (iter & 1) { - this->flip_flag = !(this->flip_flag); - } -#else volatile int *href = this->flag_href; while (*href > 0) { } *href = iter; -#endif } } diff --git a/ark/gpu/gpu_kernel.h b/ark/gpu/gpu_kernel.h index 648adcc88..292be6e2b 100644 --- a/ark/gpu/gpu_kernel.h +++ b/ark/gpu/gpu_kernel.h @@ -33,7 +33,6 @@ class GpuKernel ~GpuKernel(); void compile(const GpuInfo &gpu_info, bool use_comm_sw = true); - void load(); GpuState launch(GpuStream stream); const std::string &get_name() diff --git a/ark/include/ark.h b/ark/include/ark.h index d34a9a60c..3e036dfe8 100644 --- a/ark/include/ark.h +++ b/ark/include/ark.h @@ -171,9 +171,21 @@ class Tensor /// After read, the data in the host buffer will be 0, 1, 2, 4, 5, 6. /// /// @param buf The host buffer to copy to. The buffer must be large enough - /// to hold the data. + /// to hold the data. If @p buf is nullptr, a new buffer will be allocated. + /// @return The host buffer that holds the data. + /// + void *read(void *buf = nullptr); + + /// Copy all the underlying buffer data (including padding) to a contiguous + /// host buffer. + /// + /// This function is mainly for debugging purposes. + /// + /// @param buf The host buffer to copy to. The buffer must be large enough + /// to hold the data. If @p buf is nullptr, a new buffer will be allocated. + /// @return The host buffer that holds the data. /// - void read(void *buf); + void *read_raw(void *buf = nullptr); /// Set all bytes of the tensor buffer to 0. void clear(); diff --git a/ark/include/ark_utils.h b/ark/include/ark_utils.h index 96a054ceb..1cea610ef 100644 --- a/ark/include/ark_utils.h +++ b/ark/include/ark_utils.h @@ -51,6 +51,8 @@ template std::unique_ptr rand_array(size_t num, float max_val) std::unique_ptr rand_halfs(size_t num, float max_val); // Return a random float array. std::unique_ptr rand_floats(size_t num, float max_val); +// Return a random bytes array. +std::unique_ptr rand_bytes(size_t num); // Return a half_t range array. std::unique_ptr range_halfs(size_t num, float begin = 1.0f, @@ -79,32 +81,6 @@ template std::unique_ptr ones(size_t num) return std::unique_ptr(ret); } -// Return the error rate between two values. -float error_rate(half_t a, half_t b); -float error_rate(float a, float b); - -// Return mean squared error and max error rate between two matrices. -std::pair cmp_matrix(half_t *ground_truth, half_t *res, - unsigned int m, unsigned int n, - unsigned int bs = 1, unsigned int lm = 0, - unsigned int ln = 0, bool print = false); -std::pair cmp_matrix(float *ground_truth, float *res, - unsigned int m, unsigned int n, - unsigned int bs = 1, unsigned int lm = 0, - unsigned int ln = 0, bool print = false); - -// Print a matrix. -void print_matrix(half_t *val, unsigned int m, unsigned int n, unsigned int bs, - unsigned int lm, unsigned int ln); -void print_matrix(float *val, unsigned int m, unsigned int n, unsigned int bs, - unsigned int lm, unsigned int ln); - -// -std::pair tensor_compare(half_t *ground_truth, half_t *res, - Dims shape, bool print); -std::pair tensor_compare(float *ground_truth, float *res, - Dims shape, bool print); - // Spawn a process that runs `func`. Returns PID of the spawned process. int proc_spawn(const std::function &func); // Wait for a spawned process with PID `pid`. diff --git a/ark/include/kernels/unit_op.h b/ark/include/kernels/unit_op.h index 167b15d7b..21206303a 100644 --- a/ark/include/kernels/unit_op.h +++ b/ark/include/kernels/unit_op.h @@ -65,9 +65,10 @@ struct UnitOp static_assert(_SmemBytes >= 0, "Bytes of shared memory is negative"); // Number of unit operators in each dimension. - using UnitOpDims = - Vec<_OutDims::N / _UnitOutDims::N, _OutDims::C / _UnitOutDims::C, - _OutDims::H / _UnitOutDims::H, _OutDims::W / _UnitOutDims::W>; + using UnitOpDims = Vec::value, + math::div_up<_OutShape::C, _UnitOutDims::C>::value, + math::div_up<_OutShape::H, _UnitOutDims::H>::value, + math::div_up<_OutShape::W, _UnitOutDims::W>::value>; static const int NumThreads = _NumThreads; static const int SmemBytes = _SmemBytes; diff --git a/ark/ops/ops_add_test.cc b/ark/ops/ops_add_test.cc index 5e61dc334..b44a09989 100644 --- a/ark/ops/ops_add_test.cc +++ b/ark/ops/ops_add_test.cc @@ -2,41 +2,113 @@ // Licensed under the MIT license. #include "include/ark.h" +#include "include/ark_utils.h" #include "ops_test_common.h" #include "unittest/unittest_utils.h" +template +void baseline_add(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *t0 = static_cast(inputs[0]); + T *t1 = static_cast(inputs[1]); + + // NumPy-style broadcasted addition + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish0 = input_shapes[0].dims4(); + ark::Dims ish1 = input_shapes[1].dims4(); + for (ark::DimType n = 0; n < osh[0]; ++n) { + for (ark::DimType c = 0; c < osh[1]; ++c) { + for (ark::DimType h = 0; h < osh[2]; ++h) { + for (ark::DimType w = 0; w < osh[3]; ++w) { + out[w + h * osh[3] + c * osh[2] * osh[3] + + n * osh[1] * osh[2] * osh[3]] = + t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + + (c % ish0[1]) * ish0[2] * ish0[3] + + (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] + + t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + + (c % ish1[1]) * ish1[2] * ish1[3] + + (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; + } + } + } + } +}; + ark::unittest::State test_add_fp32() { - test_bcast_fp32("add", 2, 1024, 512); - test_bcast_fp32("add", 1, 1, 64); - test_bcast_fp32("add", 1, 128, 128); - test_bcast_fp32("add", 1, 1024, 512); - test_bcast_fp32("add", 1, 512, 1024); - test_bcast_fp32("add", 2, 1, 64); - test_bcast_fp32("add", 2, 128, 128); - test_bcast_fp32("add", 4, 1024, 512); - test_bcast_fp32("add", 4, 512, 1024); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP32); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP32); + ark::Tensor *out = m.add(t0, t1); + + auto result = + ark::op_test("add_fp32", m, {t0, t1}, {out}, baseline_add); + ark::op_test_log(result); return ark::unittest::SUCCESS; } ark::unittest::State test_add_fp16() { - test_bcast_fp16("add", 1, 1, 2); - test_bcast_fp16("add", 1, 1, 64); - test_bcast_fp16("add", 1, 128, 128); - test_bcast_fp16("add", 1, 1024, 512); - test_bcast_fp16("add", 1, 512, 1024); - test_bcast_fp16("add", 2, 1, 64); - test_bcast_fp16("add", 2, 128, 128); - test_bcast_fp16("add", 4, 1024, 512); - test_bcast_fp16("add", 4, 512, 1024); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *out = m.add(t0, t1); + + auto result = + ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add); + ark::op_test_log(result); return ark::unittest::SUCCESS; } ark::unittest::State test_add_overwrite() { - test_bcast_fp32("add", 2, 1024, 512, true); - test_bcast_fp16("add", 2, 1024, 512, true); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *out = m.add(t0, t1, t1); + + auto result = ark::op_test("add_overwrite", m, {t0, t1}, {out}, + baseline_add); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_add_broadcast() +{ + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(4, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(1, 1024), ark::FP16); + ark::Tensor *out = m.add(t0, t1); + + auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, + baseline_add); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(4, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(4, 1), ark::FP16); + ark::Tensor *out = m.add(t0, t1); + + auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, + baseline_add); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(3, 1, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(1, 4, 1), ark::FP16); + ark::Tensor *out = m.add(t0, t1); + + auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, + baseline_add); + ark::op_test_log(result); + } return ark::unittest::SUCCESS; } @@ -46,5 +118,6 @@ int main() UNITTEST(test_add_fp32); UNITTEST(test_add_fp16); UNITTEST(test_add_overwrite); + UNITTEST(test_add_broadcast); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_all_reduce_test.cc b/ark/ops/ops_all_reduce_test.cc index dc8cfd07f..2c9642b52 100644 --- a/ark/ops/ops_all_reduce_test.cc +++ b/ark/ops/ops_all_reduce_test.cc @@ -5,12 +5,12 @@ #include "include/ark.h" #include "include/ark_utils.h" #include "logging.h" +#include "ops_test_common.h" #include "unittest/unittest_utils.h" using namespace std; using namespace ark; -// used for print the all_reduce result and check the correctness -// #define PRINT_MATRIX + void test_all_reduce_internal(size_t bytes, int num_gpus, int iter) { // bytes/num_gpus is the number of bytes a GPU send in one iteration, the @@ -40,22 +40,7 @@ void test_all_reduce_internal(size_t bytes, int num_gpus, int iter) } gt[i] = ark::half_t(sum); } -#ifdef PRINT_MATRIX - // print input data. - for (int gpu_id = 0; gpu_id < num_gpus; gpu_id++) { - cout << "input data of gpu_id: " << gpu_id << endl; - for (size_t i = 0; i < bytes / sizeof(ark::half_t) && i < 10; i++) { - cout << (float)input_data[gpu_id].get()[i] << " "; - } - cout << endl; - } - // print ground truth. - cout << "ground truth: " << endl; - for (size_t i = 0; i < bytes / sizeof(ark::half_t) && i < 10; i++) { - cout << (float)gt[i] << " "; - } - cout << endl; -#endif + for (int gpu_id = 0; gpu_id < num_gpus; ++gpu_id) { ark::unittest::spawn_process([gpu_id, num_gpus, &input_data, >, bytes, iter]() { @@ -84,25 +69,13 @@ void test_all_reduce_internal(size_t bytes, int num_gpus, int iter) allreduce_result->read(res); // Compare results with the ground truth. - auto p = ark::utils::cmp_matrix((ark::half_t *)gt, - (ark::half_t *)res, 1, bytes / 2); -#ifdef PRINT_MATRIX - // print result, to avoid too long output, only print the first 10 - // elements if(gpu_id == 0) - { - cout << "result on gpu_id: " << gpu_id << " "; - for (size_t i = 0; i < bytes / sizeof(ark::half_t) && i < 10; - i++) { - cout << (float)res[i] << " "; - } - cout << endl; - } -#endif + auto comp = tensor_compare(gt, res, allreduce_result->shape); + free(res); LOG(ark::INFO, " all_reduce on gpu: ", gpu_id, " num_gpus: ", num_gpus, " total_bytes: ", bytes, - " iter: ", iter, setprecision(4), " mse: ", p.first, - " max_err: ", p.second * 100, "%", + " iter: ", iter, setprecision(4), " mse: ", comp.mse, + " max_err: ", comp.max_error_rate * 100, "%", " elapsed_msec: ", elapsed_msec, "ms"); return ark::unittest::SUCCESS; }); diff --git a/ark/ops/ops_gelu_test.cc b/ark/ops/ops_gelu_test.cc index c129e1140..2b8813bd7 100644 --- a/ark/ops/ops_gelu_test.cc +++ b/ark/ops/ops_gelu_test.cc @@ -1,85 +1,46 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_kernel.h" #include "include/ark.h" #include "include/ark_utils.h" -#include "logging.h" +#include "ops_test_common.h" #include "unittest/unittest_utils.h" #include -using namespace std; - float gelu(float x) { return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * pow(x, 3)))); } -// -void test_gelu_internal(unsigned int bs, unsigned int n, unsigned int m) +template +void baseline_gelu(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &) { - unsigned int len = bs * m * n; - // Set data. - ark::srand(); - auto data_x = ark::utils::rand_halfs(len, 0.01); - - // Get ground truth - void *gt = malloc(len * sizeof(ark::half_t)); - UNITTEST_NE(gt, (void *)nullptr); - for (unsigned int i = 0; i < len; ++i) { - ((ark::half_t *)gt)[i] = gelu(((ark::half_t *)data_x.get())[i]); + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + ark::Dims osh = output_shapes[0]; + for (ark::DimType i = 0; i < osh.size(); ++i) { + out[i] = gelu(input[i]); } - // - ark::Model model; - ark::Tensor *tns_x = model.tensor({bs, n, m}, ark::FP16); - ark::Tensor *tns_y = model.gelu(tns_x); - - // - ark::Executor exe{0, 0, 1, model, "test_gelu"}; - exe.compile(); - - // Set data. - tns_x->write(data_x.get()); - - exe.launch(); - exe.run(1); - exe.stop(); - - // Copy results of the loop kernel routine into CPU memory. - void *res = malloc(len * sizeof(ark::half_t)); - UNITTEST_NE(res, (void *)nullptr); - tns_y->read(res); +}; - // Compare results with the ground truth. - auto p = - ark::utils::cmp_matrix((ark::half_t *)gt, (ark::half_t *)res, m, n, bs); - float max_err = p.second; - LOG(ark::INFO, "gelu:", n, 'x', m, ",bs=", bs, setprecision(4), " mse ", - p.first, " max_err ", max_err * 100, "%"); - - free(res); - free(gt); - - UNITTEST_EQ(max_err, 0.0); -} - -ark::unittest::State test_gelu() +ark::unittest::State test_gelu_fp32() { - test_gelu_internal(1, 1, 64); - test_gelu_internal(1, 64, 64); - test_gelu_internal(1, 128, 128); - test_gelu_internal(1, 4096, 1024); - test_gelu_internal(1, 1024, 4096); - test_gelu_internal(2, 1, 64); - test_gelu_internal(2, 128, 128); - test_gelu_internal(8, 4096, 1024); - test_gelu_internal(8, 1024, 4096); + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.gelu(t); + + auto result = + ark::op_test("gelu_fp32", m, {t}, {out}, baseline_gelu); + ark::op_test_log(result); return ark::unittest::SUCCESS; } int main() { ark::init(); - UNITTEST(test_gelu); + UNITTEST(test_gelu_fp32); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_im2col_test.cc b/ark/ops/ops_im2col_test.cc index be5436e15..027c72b3c 100644 --- a/ark/ops/ops_im2col_test.cc +++ b/ark/ops/ops_im2col_test.cc @@ -5,6 +5,7 @@ #include "include/ark.h" #include "include/ark_utils.h" #include "logging.h" +#include "ops_test_common.h" #include "unittest/unittest_utils.h" using namespace std; @@ -35,8 +36,6 @@ void test_im2col_internal(ark::DimType n, ark::DimType h, ark::DimType w, ark::utils::range_halfs(tns_x->shape_bytes(), 0.00001, 0.00001); tns_x->write(data_x.get()); - // ark::utils::print_matrix(data_x.get(), h * w, c, h * w, c); - exe.launch(); exe.run(1); exe.stop(); @@ -89,16 +88,15 @@ void test_im2col_internal(ark::DimType n, ark::DimType h, ark::DimType w, } // Compare results with the ground truth. - auto p = ark::utils::cmp_matrix((ark::half_t *)gt, (ark::half_t *)res, mdim, - inner_dim, n, mdim, inner_dim); - float max_err = p.second; + auto comp = tensor_compare(gt, res, tns_y->shape); + float max_err = comp.max_error_rate; stringstream ss; ss << "im2col:n=" << n << ",c=" << c << ",h=" << h << ",w=" << w << ",kh=" << kernel_height << ",kw=" << kernel_width << ",sh=" << stride_height << ",sw=" << stride_width << ",ph=" << pad_height << ",pw=" << pad_width << ",dh=" << dilation_height << ",dw=" << dilation_width - << setprecision(4) << " mse " << p.first << " max_err " << max_err * 100 + << setprecision(4) << " mse " << comp.mse << " max_err " << max_err * 100 << "%"; LOG(ark::INFO, ss.str()); diff --git a/ark/ops/ops_matmul_test.cc b/ark/ops/ops_matmul_test.cc deleted file mode 100644 index a03c3baa5..000000000 --- a/ark/ops/ops_matmul_test.cc +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "gpu/gpu_kernel.h" -#include "include/ark.h" -#include "include/ark_utils.h" -#include "logging.h" -#include "unittest/unittest_utils.h" -#include - -using namespace std; - -// m,n,k: Problem size. CAUTION: `m` and `n` are assumed to be multiple of 16. -// bs_a: Batch size of left-side matrix. -// bs_b: Batch size of right-side matrix. -// iter: Number of iterations. -void test_matmul_internal(unsigned int m, unsigned int n, unsigned int k, - unsigned int bs_a, unsigned int bs_b, int split_k = 1, - int gran_lev = -1, unsigned int iter = 1) -{ - assert(bs_a == bs_b || bs_a == 1 || bs_b == 1); - unsigned int bs_res = bs_a > bs_b ? bs_a : bs_b; - - ark::GpuMgr *mgr = ark::get_gpu_mgr(0); - ark::GpuMgrCtx *ctx = mgr->create_context("test_simple_matmul_nt", 0, 1); - - size_t buf_a_sz = - (size_t)bs_a * (size_t)m * (size_t)k * sizeof(ark::half_t); - size_t buf_b_sz = - (size_t)bs_b * (size_t)k * (size_t)n * sizeof(ark::half_t); - size_t buf_res_sz = - (size_t)bs_res * (size_t)m * (size_t)n * sizeof(ark::half_t); - - // Reserved GPU buffers for execution of a manually written kernel, - // `simple_matmul_nt`. - ark::GpuBuf *buf_a = ctx->mem_alloc(buf_a_sz); - ark::GpuBuf *buf_b = ctx->mem_alloc(buf_b_sz); - ark::GpuBuf *buf_gt = ctx->mem_alloc(buf_res_sz); - // ark::GpuBuf *buf_res = ctx->mem_alloc(buf_res_sz); - - ctx->freeze(); - - bool is_relu = false; - - // Define `simple_matmul_nt` kernel to generate the ground truth. - ark::GpuKernel gk{"simple_matmul_nt", - {ark::unittest::get_kernel_code("simple_matmul_nt")}, - {n / 16, m / 16, 1}, - {16, 16, 1}, - 0, - {buf_gt, buf_a, buf_b}, - {}, - {{&m, sizeof(m)}, - {&n, sizeof(n)}, - {&k, sizeof(k)}, - {&is_relu, sizeof(is_relu)}}, - ""}; - gk.compile(mgr->get_gpu_info()); - gk.load(); - - // Generate random data for tests. - ark::srand(); - auto data_a = ark::utils::rand_halfs(buf_a_sz / sizeof(ark::half_t), 0.001); - auto data_b = ark::utils::rand_halfs(buf_b_sz / sizeof(ark::half_t), 0.001); - ark::gpu_memcpy(buf_a, data_a.get(), buf_a_sz); - ark::gpu_memcpy(buf_b, data_b.get(), buf_b_sz); - - // Run the GPU kernel. - ark::GpuStream s = ctx->create_stream(); - int ret = gk.launch(s); - UNITTEST_EQ(ret, 0); - ret = ctx->sync_stream(s); - UNITTEST_EQ(ret, 0); - - // Copy the ground truth results into CPU memory. - void *gt = malloc(buf_res_sz); - UNITTEST_NE(gt, (void *)nullptr); - ark::gpu_memcpy(gt, buf_gt, buf_res_sz); - - // Declare an equivalent matmul using Model APIs. - ark::Model model; - ark::Tensor *tns_a = model.tensor({m, k}, ark::FP16); - ark::Tensor *tns_b = model.tensor({k, n}, ark::FP16); - ark::Tensor *tns_res = model.matmul(tns_a, tns_b, nullptr, split_k, false, - false, "matmul", gran_lev); - - mgr->destroy_context(ctx); - - // - ark::Executor exe{0, 0, 1, model, "test_matmul_nt"}; - exe.compile(); - - tns_a->write(data_a.get()); - tns_b->write(data_b.get()); - - exe.launch(); - exe.run(iter); - float elapsed = exe.stop(); - - // Copy results of the loop kernel routine into CPU memory. - void *res = malloc(buf_res_sz); - UNITTEST_NE(res, (void *)nullptr); - tns_res->read(res); - - // Calculate CPU results - // float temp; - // unsigned int h; - // unsigned int w; - // for (unsigned int i = 0; i < (size_t)bs_res * (size_t)m * (size_t)n; ++i) - // { - // temp = 0; - // h = i % m; - // w = i / m; - // for (unsigned int j = 0; j < k; ++j) { - // temp += (float)(data_a.get()[j * m + h]) * (float)(data_b.get()[j - // * n + w]); - // } - // ((ark::half_t *)gt)[i] = ark::half_t(temp); - // } - - // Compare results with the ground truth. - auto p = - ark::utils::cmp_matrix((ark::half_t *)gt, (ark::half_t *)res, m, n); - float max_err = p.second; - LOG(ark::INFO, "matmul:", m, 'x', n, 'x', k, "(split_k=", split_k, - ",gran_lev=", gran_lev, ") ", setprecision(4), " mse ", p.first, - " max_err ", max_err * 100, "%", " elapsed ", elapsed, "ms iter ", - iter); - - free(res); - free(gt); - - UNITTEST_EQ(max_err, 0.0); -} - -ark::unittest::State test_matmul_gran0() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - test_matmul_internal(/*m=*/256, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/0); - - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/0); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/0); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_gran1() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - test_matmul_internal(/*m=*/256, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/1); - - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/1); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/1); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_gran2() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - test_matmul_internal(/*m=*/256, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, - /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/2); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_gran3() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/32, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - test_matmul_internal(/*m=*/256, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/3); - - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/3); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/1, /*gran_lev=*/3); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_split() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/2, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/2, /*gran_lev=*/2); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/2, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/2, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/2, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/2, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/128, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/4, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/128, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/4, /*gran_lev=*/2); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/128, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/4, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/128, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/256, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/4, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/3, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/3, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/3, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/3, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/5, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/5, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/5, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/5, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/6, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/6, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/6, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/6, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/7, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/7, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/7, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/7, /*gran_lev=*/2); - - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/8, /*gran_lev=*/0); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/8, /*gran_lev=*/1); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/8, /*gran_lev=*/2); - test_matmul_internal(/*m=*/128, /*n=*/4096, /*k=*/1024, /*bs_a=*/1, - /*bs_b=*/1, /*split_k=*/8, /*gran_lev=*/2); - - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_perf() -{ - test_matmul_internal(/*m=*/64, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/-1, /*iter=*/1000); - test_matmul_internal(/*m=*/64, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/-1, /*iter=*/1000); - test_matmul_internal(/*m=*/128, /*n=*/64, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/-1, /*iter=*/1000); - test_matmul_internal(/*m=*/128, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/-1, /*iter=*/1000); - test_matmul_internal(/*m=*/256, /*n=*/128, /*k=*/64, /*bs_a=*/1, /*bs_b=*/1, - /*split_k=*/1, /*gran_lev=*/-1, /*iter=*/1000); - return ark::unittest::SUCCESS; -} - -int main() -{ - ark::init(); - UNITTEST(test_matmul_gran0); - UNITTEST(test_matmul_gran1); - UNITTEST(test_matmul_gran2); - // UNITTEST(test_matmul_gran3); - UNITTEST(test_matmul_split); - UNITTEST(test_matmul_perf); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_matmul_test.cu b/ark/ops/ops_matmul_test.cu new file mode 100644 index 000000000..04c8d6b11 --- /dev/null +++ b/ark/ops/ops_matmul_test.cu @@ -0,0 +1,483 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark_utils.h" +#include "ops_test_common.h" +#include +#include +#include + +cublasHandle_t globalCublasHandle = nullptr; + +cublasHandle_t get_cublas_handle() +{ + if (globalCublasHandle == nullptr) { + cublasStatus_t status = cublasCreate(&globalCublasHandle); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to create cublas handle"); + } + } + return globalCublasHandle; +} + +void cublas_matmul_float_nn(int m, int n, int k, const float *a, int lda, + const float *b, int ldb, float *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + float alpha = 1; + float beta = 0; + cublasStatus_t status = + cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasSgemm"); + } +} + +void cublas_matmul_float_nt(int m, int n, int k, const float *a, int lda, + const float *b, int ldb, float *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + float alpha = 1; + float beta = 0; + cublasStatus_t status = + cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasSgemm"); + } +} + +void cublas_matmul_float_tn(int m, int n, int k, const float *a, int lda, + const float *b, int ldb, float *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + float alpha = 1; + float beta = 0; + cublasStatus_t status = + cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasSgemm"); + } +} + +void cublas_matmul_float_tt(int m, int n, int k, const float *a, int lda, + const float *b, int ldb, float *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + float alpha = 1; + float beta = 0; + cublasStatus_t status = + cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasSgemm"); + } +} + +void cublas_matmul_half_nn(int m, int n, int k, const half *a, int lda, + const half *b, int ldb, half *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + half alpha = half(ark::half_t(1)); + half beta = half(ark::half_t(0)); + cublasStatus_t status = + cublasHgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasHgemm"); + } +} + +void cublas_matmul_half_nt(int m, int n, int k, const half *a, int lda, + const half *b, int ldb, half *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + half alpha = half(ark::half_t(1)); + half beta = half(ark::half_t(0)); + cublasStatus_t status = + cublasHgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasHgemm"); + } +} + +void cublas_matmul_half_tn(int m, int n, int k, const half *a, int lda, + const half *b, int ldb, half *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + half alpha = half(ark::half_t(1)); + half beta = half(ark::half_t(0)); + cublasStatus_t status = + cublasHgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasHgemm"); + } +} + +void cublas_matmul_half_tt(int m, int n, int k, const half *a, int lda, + const half *b, int ldb, half *c, int ldc) +{ + auto cublasH = get_cublas_handle(); + half alpha = half(ark::half_t(1)); + half beta = half(ark::half_t(0)); + cublasStatus_t status = + cublasHgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, + a, lda, &beta, c, ldc); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasHgemm"); + } +} + +template +void baseline_matmul_nn(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + // baseline inputs & outputs have no padding + int m = output_shapes[0].dims4()[2]; + int n = output_shapes[0].dims4()[3]; + int k = input_shapes[0].dims4()[3]; + int lda = k; + int ldb = n; + int ldc = n; + + auto memA = ark::to_gpu(inputs[0], input_shapes[0].size() * sizeof(T)); + auto memB = ark::to_gpu(inputs[1], input_shapes[1].size() * sizeof(T)); + auto memC = ark::to_gpu(outputs[0], output_shapes[0].size() * sizeof(T)); + + T *devA = static_cast(memA.get()); + T *devB = static_cast(memB.get()); + T *devC = static_cast(memC.get()); + + // matmul using cublas + if constexpr (std::is_same_v) { + cublas_matmul_float_nn(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else if constexpr (std::is_same_v) { + cublas_matmul_half_nn(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else { + throw std::runtime_error("Unsupported data type"); + } + ark::sync_gpu(); + + // copy back to host + ark::from_gpu(memC, outputs[0]); +} + +template +void baseline_matmul_nt(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + // baseline inputs & outputs have no padding + int m = output_shapes[0].dims4()[2]; + int n = output_shapes[0].dims4()[3]; + int k = input_shapes[0].dims4()[3]; + int lda = k; + int ldb = k; + int ldc = n; + + auto memA = ark::to_gpu(inputs[0], input_shapes[0].size() * sizeof(T)); + auto memB = ark::to_gpu(inputs[1], input_shapes[1].size() * sizeof(T)); + auto memC = ark::to_gpu(outputs[0], output_shapes[0].size() * sizeof(T)); + + T *devA = static_cast(memA.get()); + T *devB = static_cast(memB.get()); + T *devC = static_cast(memC.get()); + + // matmul using cublas + if constexpr (std::is_same_v) { + cublas_matmul_float_nt(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else if constexpr (std::is_same_v) { + cublas_matmul_half_nt(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else { + throw std::runtime_error("Unsupported data type"); + } + ark::sync_gpu(); + + // copy back to host + ark::from_gpu(memC, outputs[0]); +} + +template +void baseline_matmul_tn(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + // baseline inputs & outputs have no padding + int m = output_shapes[0].dims4()[2]; + int n = output_shapes[0].dims4()[3]; + int k = input_shapes[0].dims4()[2]; + int lda = m; + int ldb = n; + int ldc = n; + + auto memA = ark::to_gpu(inputs[0], input_shapes[0].size() * sizeof(T)); + auto memB = ark::to_gpu(inputs[1], input_shapes[1].size() * sizeof(T)); + auto memC = ark::to_gpu(outputs[0], output_shapes[0].size() * sizeof(T)); + + T *devA = static_cast(memA.get()); + T *devB = static_cast(memB.get()); + T *devC = static_cast(memC.get()); + + // matmul using cublas + if constexpr (std::is_same_v) { + cublas_matmul_float_tn(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else if constexpr (std::is_same_v) { + cublas_matmul_half_tn(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else { + throw std::runtime_error("Unsupported data type"); + } + ark::sync_gpu(); + + // copy back to host + ark::from_gpu(memC, outputs[0]); +} + +template +void baseline_matmul_tt(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + // baseline inputs & outputs have no padding + int m = output_shapes[0].dims4()[2]; + int n = output_shapes[0].dims4()[3]; + int k = input_shapes[0].dims4()[2]; + int lda = m; + int ldb = k; + int ldc = n; + + auto memA = ark::to_gpu(inputs[0], input_shapes[0].size() * sizeof(T)); + auto memB = ark::to_gpu(inputs[1], input_shapes[1].size() * sizeof(T)); + auto memC = ark::to_gpu(outputs[0], output_shapes[0].size() * sizeof(T)); + + T *devA = static_cast(memA.get()); + T *devB = static_cast(memB.get()); + T *devC = static_cast(memC.get()); + + // matmul using cublas + if constexpr (std::is_same_v) { + cublas_matmul_float_tt(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else if constexpr (std::is_same_v) { + cublas_matmul_half_tt(m, n, k, devA, lda, devB, ldb, devC, ldc); + } else { + throw std::runtime_error("Unsupported data type"); + } + ark::sync_gpu(); + + // copy back to host + ark::from_gpu(memC, outputs[0]); +} + +ark::unittest::State test_matmul_gran0() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_gran0", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_gran0", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_gran1() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 1); + + auto result = ark::op_test("matmul_gran1", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 1); + + auto result = ark::op_test("matmul_gran1", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_gran2() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 2); + + auto result = ark::op_test("matmul_gran2", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 2); + + auto result = ark::op_test("matmul_gran2", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_split() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 7, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_split", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_fp32() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP32); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP32); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP32); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP32); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, + baseline_matmul_nn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_nt() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(256, 64), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, true, "matmul", 0); + + auto result = + ark::op_test("matmul_nt", m, {a, b}, {c}, baseline_matmul_nt); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(16384, 8192), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, true, "matmul", 0); + + auto result = + ark::op_test("matmul_nt", m, {a, b}, {c}, baseline_matmul_nt); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_tn() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(64, 128), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, true, false, "matmul", 0); + + auto result = + ark::op_test("matmul_tn", m, {a, b}, {c}, baseline_matmul_tn, "ones", true); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(8192, 4096), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, true, false, "matmul", 0); + + auto result = + ark::op_test("matmul_tn", m, {a, b}, {c}, baseline_matmul_tn); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_tt() +{ + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(64, 128), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(256, 64), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, true, true, "matmul", 0); + + auto result = + ark::op_test("matmul_tt", m, {a, b}, {c}, baseline_matmul_tt); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(8192, 4096), ark::FP16); + ark::Tensor *b = m.tensor(ark::Dims(16384, 8192), ark::FP16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, true, true, "matmul", 0); + + auto result = + ark::op_test("matmul_tt", m, {a, b}, {c}, baseline_matmul_tt); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +int main() +{ + ark::init(); + UNITTEST(test_matmul_gran0); + UNITTEST(test_matmul_gran1); + UNITTEST(test_matmul_gran2); + UNITTEST(test_matmul_split); + UNITTEST(test_matmul_fp32); + UNITTEST(test_matmul_nt); + UNITTEST(test_matmul_tn); + UNITTEST(test_matmul_tt); + + cublasDestroy(get_cublas_handle()); + return ark::unittest::SUCCESS; +} diff --git a/ark/ops/ops_mul_test.cc b/ark/ops/ops_mul_test.cc index 29b5456d3..b8b9af97a 100644 --- a/ark/ops/ops_mul_test.cc +++ b/ark/ops/ops_mul_test.cc @@ -2,41 +2,113 @@ // Licensed under the MIT license. #include "include/ark.h" +#include "include/ark_utils.h" #include "ops_test_common.h" #include "unittest/unittest_utils.h" +template +void baseline_mul(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *t0 = static_cast(inputs[0]); + T *t1 = static_cast(inputs[1]); + + // NumPy-style broadcasted multiplication + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish0 = input_shapes[0].dims4(); + ark::Dims ish1 = input_shapes[1].dims4(); + for (ark::DimType n = 0; n < osh[0]; ++n) { + for (ark::DimType c = 0; c < osh[1]; ++c) { + for (ark::DimType h = 0; h < osh[2]; ++h) { + for (ark::DimType w = 0; w < osh[3]; ++w) { + out[w + h * osh[3] + c * osh[2] * osh[3] + + n * osh[1] * osh[2] * osh[3]] = + t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + + (c % ish0[1]) * ish0[2] * ish0[3] + + (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] * + t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + + (c % ish1[1]) * ish1[2] * ish1[3] + + (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; + } + } + } + } +}; + ark::unittest::State test_mul_fp32() { - test_bcast_fp32("mul", 1, 1, 2); - test_bcast_fp32("mul", 1, 1, 64); - test_bcast_fp32("mul", 1, 128, 128); - test_bcast_fp32("mul", 1, 1024, 512); - test_bcast_fp32("mul", 1, 512, 1024); - test_bcast_fp32("mul", 2, 1, 64); - test_bcast_fp32("mul", 2, 128, 128); - test_bcast_fp32("mul", 4, 1024, 512); - test_bcast_fp32("mul", 4, 512, 1024); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP32); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP32); + ark::Tensor *out = m.mul(t0, t1); + + auto result = + ark::op_test("mul_fp32", m, {t0, t1}, {out}, baseline_mul); + ark::op_test_log(result); return ark::unittest::SUCCESS; } ark::unittest::State test_mul_fp16() { - test_bcast_fp16("mul", 1, 1, 2); - test_bcast_fp16("mul", 1, 1, 64); - test_bcast_fp16("mul", 1, 128, 128); - test_bcast_fp16("mul", 1, 1024, 512); - test_bcast_fp16("mul", 1, 512, 1024); - test_bcast_fp16("mul", 2, 1, 64); - test_bcast_fp16("mul", 2, 128, 128); - test_bcast_fp16("mul", 4, 1024, 512); - test_bcast_fp16("mul", 4, 512, 1024); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *out = m.mul(t0, t1); + + auto result = + ark::op_test("mul_fp16", m, {t0, t1}, {out}, baseline_mul); + ark::op_test_log(result); return ark::unittest::SUCCESS; } ark::unittest::State test_mul_overwrite() { - test_bcast_fp32("mul", 2, 1024, 512, true); - test_bcast_fp16("mul", 2, 1024, 512, true); + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::FP16); + ark::Tensor *out = m.mul(t0, t1, t1); + + auto result = ark::op_test("mul_overwrite", m, {t0, t1}, {out}, + baseline_mul); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_mul_broadcast() +{ + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(4, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(1, 1024), ark::FP16); + ark::Tensor *out = m.mul(t0, t1); + + auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, + baseline_mul); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(4, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(4, 1), ark::FP16); + ark::Tensor *out = m.mul(t0, t1); + + auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, + baseline_mul); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(3, 1, 1024), ark::FP16); + ark::Tensor *t1 = m.tensor(ark::Dims(1, 4, 1), ark::FP16); + ark::Tensor *out = m.mul(t0, t1); + + auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, + baseline_mul); + ark::op_test_log(result); + } return ark::unittest::SUCCESS; } @@ -46,5 +118,6 @@ int main() UNITTEST(test_mul_fp32); UNITTEST(test_mul_fp16); UNITTEST(test_mul_overwrite); + UNITTEST(test_mul_broadcast); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_reduce_test.cc b/ark/ops/ops_reduce_test.cc index dbebe21f6..df63ea13f 100644 --- a/ark/ops/ops_reduce_test.cc +++ b/ark/ops/ops_reduce_test.cc @@ -1,119 +1,219 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_kernel.h" #include "include/ark.h" #include "include/ark_utils.h" -#include "logging.h" +#include "ops_test_common.h" #include "unittest/unittest_utils.h" #include -using namespace std; +template +void baseline_reduce_sum_axis0(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[0] == 1); + + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + T sum = 0; + for (ark::DimType n = 0; n < ish[0]; ++n) { + sum += input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]; + } + out[c * osh[2] * osh[3] + h * osh[3] + w] = sum; + } + } + } +} -// -void test_reduce_internal(unsigned int n, unsigned int m, unsigned int k, - int axis) +template +void baseline_reduce_sum_axis1(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) { - size_t buf_x_sz = (size_t)m * (size_t)n * (size_t)k * sizeof(ark::half_t); - size_t buf_y_sz = (size_t)m * (size_t)n * sizeof(ark::half_t); - - // Set data. - ark::srand(); - auto data_a = ark::utils::rand_halfs(buf_x_sz / sizeof(ark::half_t), 0.01); - - // Copy the ground truth results into CPU memory. - void *gt = malloc(buf_y_sz); - UNITTEST_NE(gt, (void *)nullptr); - - for (unsigned int i = 0; i < n; ++i) { - for (unsigned int j = 0; j < m; ++j) { - ark::half_t v = 0; - for (unsigned int l = 0; l < k; ++l) { - int idx; - if (axis == 0) { - idx = i * m + j + l * m * n; - } else if (axis == 1) { - idx = i * m * k + j + l * m; - } else if (axis == 2) { - idx = i * m * k + j * k + l; - } else { - assert(false); + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[1] == 1); + + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + T sum = 0; + for (ark::DimType c = 0; c < ish[1]; ++c) { + sum += input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]; } - ark::half_t x = data_a[idx]; - v += x; + out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = sum; } - ((ark::half_t *)gt)[i * m + j] = v; } } +} + +template +void baseline_reduce_sum_axis2(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[2] == 1); + + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + T sum = 0; + for (ark::DimType h = 0; h < ish[2]; ++h) { + sum += input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]; + } + out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + w] = + sum; + } + } + } +}; + +template +void baseline_reduce_sum_axis3(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[3] == 1); - // - ark::Model model; - ark::Tensor *tns_x = nullptr; - ark::Tensor *tns_y = nullptr; - if (axis == 0) { - tns_x = model.tensor({k, n, m}, ark::FP16); - tns_y = model.tensor({1, n, m}, ark::FP16); - } else if (axis == 1) { - tns_x = model.tensor({n, k, m}, ark::FP16); - tns_y = model.tensor({n, 1, m}, ark::FP16); - } else if (axis == 2) { - tns_x = model.tensor({n, m, k}, ark::FP16); - tns_y = model.tensor({n, m, 1}, ark::FP16); - } else { - LOG(ark::ERROR, "invalid axis"); + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + T sum = 0; + for (ark::DimType w = 0; w < ish[3]; ++w) { + sum += input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]; + } + out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + + h * osh[3]] = sum; + } + } } +}; - model.reduce_sum(tns_x, axis, tns_y); +ark::unittest::State test_reduce_axis0() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP32); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/0); - // - ark::Executor exe{0, 0, 1, model, "test_reduce"}; - exe.compile(); + auto result = ark::op_test("reduce_axis0", m, {t}, {out}, + baseline_reduce_sum_axis0); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} - // Set data. - tns_x->write(data_a.get()); +ark::unittest::State test_reduce_axis1() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(1, 2, 4, 1024), ark::FP32); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/1); - exe.launch(); - exe.run(1); - exe.stop(); + auto result = ark::op_test("reduce_axis1", m, {t}, {out}, + baseline_reduce_sum_axis1); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} - // Copy results of the loop kernel routine into CPU memory. - void *res = malloc(buf_y_sz); - UNITTEST_NE(res, (void *)nullptr); - tns_y->read(res); +ark::unittest::State test_reduce_axis2() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(1, 1, 7, 8192), ark::FP32); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/2); - // Compare results with the ground truth. - auto p = - ark::utils::cmp_matrix((ark::half_t *)gt, (ark::half_t *)res, m, n); - float max_err = p.second; - LOG(ark::INFO, "reduce:", n, 'x', m, 'x', k, " axis ", axis, " ", - setprecision(4), " mse ", p.first, " max_err ", max_err * 100, "%"); + auto result = ark::op_test("reduce_axis2", m, {t}, {out}, + baseline_reduce_sum_axis2); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} - free(res); - free(gt); +ark::unittest::State test_reduce_axis3() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/3); - UNITTEST_EQ(max_err, 0.0); + auto result = ark::op_test("reduce_axis3", m, {t}, {out}, + baseline_reduce_sum_axis3); + ark::op_test_log(result); + return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce() +ark::unittest::State test_reduce_axis3_padded() { - // TODO: implement reduce for axis = 0 and axis = 1 - for (int axis = 2; axis < 3; axis++) { - test_reduce_internal(1, 64, 2, axis); - test_reduce_internal(1, 64, 8, axis); - test_reduce_internal(1, 64, 9, axis); - test_reduce_internal(2, 64, 4, axis); - test_reduce_internal(8, 64, 4, axis); - test_reduce_internal(64, 64, 4, axis); - test_reduce_internal(1, 256, 256, axis); - test_reduce_internal(1024, 384, 4, axis); - } + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); + ark::Tensor *out = m.tensor(ark::Dims(1, 1, 2, 1), ark::FP32, nullptr, + ark::Dims(1, 1, 2, 32)); + out = m.reduce_sum(t, /*axis=*/3, out); + + auto result = ark::op_test("reduce_axis3_padded", m, {t}, {out}, + baseline_reduce_sum_axis3); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} +ark::unittest::State test_reduce_fp16() +{ + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/0); + + auto result = ark::op_test("reduce_fp16_axis0", m, {t}, {out}, + baseline_reduce_sum_axis0); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/3); + + auto result = ark::op_test("reduce_fp16_axis3", m, {t}, {out}, + baseline_reduce_sum_axis3); + ark::op_test_log(result); + } return ark::unittest::SUCCESS; } int main() { ark::init(); - UNITTEST(test_reduce); + UNITTEST(test_reduce_axis0); + UNITTEST(test_reduce_axis1); + UNITTEST(test_reduce_axis2); + UNITTEST(test_reduce_axis3); + UNITTEST(test_reduce_axis3_padded); + UNITTEST(test_reduce_fp16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_scale_test.cc b/ark/ops/ops_scale_test.cc index 09bd96c62..48f096685 100644 --- a/ark/ops/ops_scale_test.cc +++ b/ark/ops/ops_scale_test.cc @@ -1,111 +1,42 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_kernel.h" #include "include/ark.h" #include "include/ark_utils.h" -#include "logging.h" +#include "ops_test_common.h" #include "unittest/unittest_utils.h" -using namespace std; +#define SCALE_FACTOR 0.7 -// -void test_scale_internal(unsigned int bs, unsigned int n, unsigned int m, - float val = 0.7) +template +void baseline_scale(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &) { - ark::GpuMgr *mgr = ark::get_gpu_mgr(0); - ark::GpuMgrCtx *ctx = mgr->create_context("test_simple_scale", 0, 1); - - unsigned int len = bs * m * n; - ark::GpuBuf *buf_x = ctx->mem_alloc(len * sizeof(ark::half_t)); - ark::GpuBuf *buf_y = ctx->mem_alloc(len * sizeof(ark::half_t)); - - ctx->freeze(); - - ark::GpuKernel gk{"simple_scale", - {ark::unittest::get_kernel_code("simple_scale")}, - {(unsigned int)mgr->get_gpu_info().num_sm, 1, 1}, - {512, 1, 1}, - 0, - {buf_y, buf_x}, - {}, - { - {&val, sizeof(val)}, - {&len, sizeof(len)}, - }, - ""}; - gk.compile(mgr->get_gpu_info()); - gk.load(); - - // Set data. - ark::srand(); - auto data_x = ark::utils::rand_halfs(len, 0.01); - ark::gpu_memcpy(buf_x, data_x.get(), len * sizeof(ark::half_t)); - - // Run the GPU kernel. - ark::GpuStream s = ctx->create_stream(); - int ret = gk.launch(s); - UNITTEST_EQ(ret, 0); - ret = ctx->sync_stream(s); - UNITTEST_EQ(ret, 0); - - // Copy the ground truth results into CPU memory. - void *gt = malloc(len * sizeof(ark::half_t)); - UNITTEST_NE(gt, (void *)nullptr); - ark::gpu_memcpy(gt, buf_y, len * sizeof(ark::half_t)); - - mgr->destroy_context(ctx); - - // - ark::Model model; - ark::Tensor *tns_x = model.tensor({bs, n, m}, ark::FP16); - ark::Tensor *tns_y = model.scale(tns_x, val); - - // - ark::Executor exe{0, 0, 1, model, "test_scale"}; - exe.compile(); - - // Set data. - tns_x->write(data_x.get()); - - exe.launch(); - exe.run(1); - exe.stop(); - - // Copy results of the loop kernel routine into CPU memory. - void *res = malloc(len * sizeof(ark::half_t)); - UNITTEST_NE(res, (void *)nullptr); - tns_y->read(res); - - // Compare results with the ground truth. - auto p = - ark::utils::cmp_matrix((ark::half_t *)gt, (ark::half_t *)res, m, n, bs); - float max_err = p.second; - LOG(ark::INFO, "scale:", n, 'x', m, ",bs=", bs, setprecision(4), " mse ", - p.first, " max_err ", max_err * 100, "%"); - - free(res); - free(gt); - - UNITTEST_EQ(max_err, 0.0); -} - -ark::unittest::State test_scale() + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + ark::Dims osh = output_shapes[0]; + for (ark::DimType i = 0; i < osh.size(); ++i) { + out[i] = input[i] * T(SCALE_FACTOR); + } +}; + +ark::unittest::State test_scale_fp16() { - test_scale_internal(1, 1, 64); - test_scale_internal(1, 128, 128); - test_scale_internal(1, 4096, 1024); - test_scale_internal(1, 1024, 4096); - test_scale_internal(2, 1, 64); - test_scale_internal(2, 128, 128); - test_scale_internal(8, 4096, 1024); - test_scale_internal(8, 1024, 4096); + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = + ark::op_test("scale_fp16", m, {t}, {out}, baseline_scale); + ark::op_test_log(result); return ark::unittest::SUCCESS; } int main() { ark::init(); - UNITTEST(test_scale); + UNITTEST(test_scale_fp16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_test_common.cc b/ark/ops/ops_test_common.cc index d46cd3a70..5743851ac 100644 --- a/ark/ops/ops_test_common.cc +++ b/ark/ops/ops_test_common.cc @@ -3,140 +3,415 @@ #include "ops_test_common.h" #include "gpu/gpu_kernel.h" -#include "include/ark_utils.h" #include "logging.h" +#include "random.h" #include "unittest/unittest_utils.h" +#include +#include -using namespace std; +namespace ark { -// op_name: "add", "mul" -// TODO: deprecate this +std::ostream &operator<<(std::ostream &os, const OpsTestResult &result) +{ + os << "op test: " << result.test_name << " #warp/sm " + << result.num_warps_per_sm << ", msec/iter " << result.msec_per_iter; + os << std::setprecision(4); + for (size_t i = 0; i < result.mse.size(); i++) { + float err_pcnt = result.max_err_rate[i] * 100; + os << ", mse " << result.mse[i] << ", max_diff " << result.max_diff[i] + << ", max_err_rate " << err_pcnt << "%"; + } + return os; +} + +/// Calculate the error rate between two values. +/// @tparam T Type of the values +/// @param a First value +/// @param b Second value +/// @return The error rate +template float error_rate(T a, T b) +{ + T diff = abs(a - b); + T max = std::max(abs(a), abs(b)); + if (max == 0) { + return 0; + } + return (float)diff / (float)max; +} + +/// Calculate the error rate between two @ref half_t values. +/// @param a First value +/// @param b Second value +/// @return The error rate +float error_rate(half_t a, half_t b) +{ + return error_rate(a, b); +} + +/// Calculate the error rate between two floats. +/// @param a First value +/// @param b Second value +/// @return The error rate +float error_rate(float a, float b) +{ + return error_rate(a, b); +} + +/// Return mean squared error and max error rate between two tensors. +/// @tparam T data type of the tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. template -void test_bcast_internal(string op_name, ark::TensorType type, ark::DimType bs, - ark::DimType n, ark::DimType m, bool overwrite) -{ - string type_name; - if (type == ark::FP32) { - type_name = "fp32"; - } else if (type == ark::FP16) { - type_name = "fp16"; - } else { - UNITTEST_FEXIT("Unsupported tensor type:", type); - } - string kernel_name = "simple_" + op_name + "_" + type_name; - - ark::GpuMgr *mgr = ark::get_gpu_mgr(0); - ark::GpuMgrCtx *ctx = mgr->create_context("test_simple_" + op_name, 0, 1); - - ark::DimType len = m * n; - ark::GpuBuf *buf_a = ctx->mem_alloc(bs * len * sizeof(T)); - ark::GpuBuf *buf_b = ctx->mem_alloc(len * sizeof(T)); - ark::GpuBuf *buf_c = ctx->mem_alloc(bs * len * sizeof(T)); - - ctx->freeze(); - - ark::GpuKernel gk{kernel_name, - {ark::unittest::get_kernel_code("simple_" + op_name)}, - {(unsigned int)mgr->get_gpu_info().num_sm, 1, 1}, - {512, 1, 1}, - 0, - {buf_c, buf_a, buf_b}, - {}, - { - {&bs, sizeof(bs)}, - {&len, sizeof(len)}, - }, - ""}; - gk.compile(mgr->get_gpu_info()); - gk.load(); - - // Set data. - ark::srand(); - auto data_a = ark::utils::rand_array(bs * len, 0.01); - auto data_b = ark::utils::rand_array(len, 0.01); - ark::gpu_memcpy(buf_a, data_a.get(), bs * len * sizeof(T)); - ark::gpu_memcpy(buf_b, data_b.get(), len * sizeof(T)); - - // Run the GPU kernel. - ark::GpuStream s = ctx->create_stream(); - int ret = gk.launch(s); - UNITTEST_EQ(ret, 0); - ret = ctx->sync_stream(s); - UNITTEST_EQ(ret, 0); - - // Copy the ground truth results into CPU memory. - T *gt = (T *)malloc(bs * len * sizeof(T)); - UNITTEST_NE(gt, (T *)nullptr); - ark::gpu_memcpy(gt, buf_c, bs * len * sizeof(T)); - - mgr->destroy_context(ctx); - - // - ark::Model model; - ark::Tensor *tns_a = model.tensor({bs, n, m}, type); - ark::Tensor *tns_b = model.tensor({1, n, m}, type); - ark::Tensor *tns_c = nullptr; - if (op_name == "add") { - if (overwrite) { - tns_c = model.add(tns_a, tns_b, tns_a); - } else { - tns_c = model.add(tns_a, tns_b); +TensorCompareResult tensor_compare(T *ground_truth, T *res, Dims shape, + bool print = false) +{ + DimType nelem = shape.size(); + int ndims = shape.ndims(); + float l2_loss = 0; + float max_err = 0; + float max_diff = 0; + for (DimType i = 0; i < nelem; ++i) { + float diff = (float)(ground_truth[i] - res[i]); + if (std::abs(diff) > max_diff) { + max_diff = std::abs(diff); } - } else if (op_name == "mul") { - if (overwrite) { - tns_c = model.mul(tns_a, tns_b, tns_a); - } else { - tns_c = model.mul(tns_a, tns_b); + l2_loss += diff * diff; + + float err = error_rate(ground_truth[i], res[i]); + if (err > 0.) { + if (print) { + Dims idx(std::vector(ndims, 0)); + for (int j = 0; j < ndims; ++j) { + DimType vol = 1; + for (int k = j + 1; k < ndims; ++k) { + vol *= shape[k]; + } + idx[j] = (i / vol) % shape[j]; + } + std::cout << idx << " expected " << ground_truth[i] + << ", actually " << res[i] << " (err: " << err << ")" + << std::endl; + } + if (err > max_err) { + max_err = err; + } } } - UNITTEST_NE(tns_c, (ark::Tensor *)nullptr); + TensorCompareResult result; + result.mse = l2_loss / nelem; + result.max_diff = max_diff; + result.max_error_rate = max_err; + return result; +} + +/// Return mean squared error and max error rate between two @ref half_t +/// tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(half_t *ground_truth, half_t *res, + Dims shape, bool print) +{ + return tensor_compare(ground_truth, res, shape, print); +} - // - ark::Executor exe{0, 0, 1, model, "test_" + op_name + "_" + type_name}; +/// Return mean squared error and max error rate between two float tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(float *ground_truth, float *res, Dims shape, + bool print) +{ + return tensor_compare(ground_truth, res, shape, print); +} + +/// Return mean squared error and max error rate between two int tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(int *ground_truth, int *res, Dims shape, + bool print) +{ + return tensor_compare(ground_truth, res, shape, print); +} + +/// Return mean squared error and max error rate between two byte tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(uint8_t *ground_truth, uint8_t *res, + Dims shape, bool print) +{ + return tensor_compare(ground_truth, res, shape, print); +} + +OpsTestResult op_test(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, const std::string &init_method, + bool print_on_error, int num_warps_per_sm) +{ + Executor exe{0, 0, 1, model, "op_test_" + rand_anum(4), num_warps_per_sm}; exe.compile(); - // Set data. - tns_a->write(data_a.get()); - tns_b->write(data_b.get()); + // Set random data. + std::vector input_data; + for (auto t : inputs) { + void *buf = ::malloc(t->shape_bytes()); + UNITTEST_NE(buf, (void *)nullptr); + + if (init_method == "random") { + if (t->type == FP32) { + ::memcpy(buf, utils::rand_floats(t->shape.size(), 0.1).get(), + t->shape_bytes()); + } else if (t->type == FP16) { + ::memcpy(buf, utils::rand_halfs(t->shape.size(), 0.1).get(), + t->shape_bytes()); + } else if (t->type == INT32) { + ::memcpy(buf, + utils::rand_array(t->shape.size(), 10000).get(), + t->shape_bytes()); + } else if (t->type == BYTE) { + ::memcpy(buf, utils::rand_bytes(t->shape.size()).get(), + t->shape_bytes()); + } else { + LOG(ERROR, "Unsupported data type: ", t->type); + } + } else if (init_method == "ones") { + if (t->type == FP32) { + ::memcpy(buf, utils::ones(t->shape.size()).get(), + t->shape_bytes()); + } else if (t->type == FP16) { + ::memcpy(buf, utils::ones(t->shape.size()).get(), + t->shape_bytes()); + } else if (t->type == INT32) { + ::memcpy(buf, utils::ones(t->shape.size()).get(), + t->shape_bytes()); + } else if (t->type == BYTE) { + ::memcpy(buf, utils::ones(t->shape.size()).get(), + t->shape_bytes()); + } else { + LOG(ERROR, "Unsupported data type: ", t->type); + } + } else { + LOG(ERROR, "Unsupported init method: ", init_method); + } + t->write(buf); + input_data.push_back(buf); + } exe.launch(); + + // Correctness test. exe.run(1); + exe.wait(); exe.stop(); // Copy results of the loop kernel routine into CPU memory. - T *res = (T *)malloc(bs * len * sizeof(T)); - UNITTEST_NE(res, (T *)nullptr); - tns_c->read(res); + std::vector res; + for (auto t : outputs) { + void *buf = ::malloc(t->shape_bytes()); + UNITTEST_NE(buf, (void *)nullptr); + t->read(buf); + res.push_back(buf); + } + + std::vector gt; + for (auto t : outputs) { + void *buf = ::malloc(t->shape_bytes()); + UNITTEST_NE(buf, (void *)nullptr); + gt.push_back(buf); + } + + std::vector output_shapes; + for (auto t : outputs) { + output_shapes.push_back(t->shape); + } + std::vector input_shapes; + for (auto t : inputs) { + input_shapes.push_back(t->shape); + } + + // Calculate ground truth. + baseline(gt, output_shapes, input_data, input_shapes); + + std::stringstream test_name; + test_name << test_name_prefix; + for (size_t i = 0; i < inputs.size(); i++) { + test_name << ";in" << i << "=" << inputs[i]->shape; + } + for (size_t i = 0; i < outputs.size(); i++) { + test_name << ";out" << i << "=" << outputs[i]->shape; + } + test_name << ";"; + + OpsTestResult result; + result.test_name = test_name.str(); + result.num_warps_per_sm = num_warps_per_sm; // Compare results with the ground truth. - std::pair p = - ark::utils::tensor_compare(gt, res, tns_c->shape, true); - float max_err = p.second; + for (size_t i = 0; i < outputs.size(); i++) { + TensorCompareResult comp; + if (outputs[i]->type == FP32) { + comp = tensor_compare(static_cast(gt[i]), + static_cast(res[i]), + outputs[i]->shape.dims4(), print_on_error); + } else if (outputs[i]->type == FP16) { + comp = tensor_compare(static_cast(gt[i]), + static_cast(res[i]), + outputs[i]->shape.dims4(), print_on_error); + } else if (outputs[i]->type == INT32) { + comp = tensor_compare(static_cast(gt[i]), + static_cast(res[i]), + outputs[i]->shape.dims4(), print_on_error); + } else if (outputs[i]->type == BYTE) { + comp = tensor_compare(static_cast(gt[i]), + static_cast(res[i]), + outputs[i]->shape.dims4(), print_on_error); + } else { + LOG(ERROR, "Unsupported data type: ", outputs[i]->type); + } + result.mse.push_back(comp.mse); + result.max_diff.push_back(comp.max_diff); + result.max_err_rate.push_back(comp.max_error_rate); + } - if (overwrite) { - tns_a->read(res); - p = ark::utils::tensor_compare(gt, res, tns_a->shape, true); - max_err = std::max(max_err, p.second); + // Throughput test. + + // Restart the executor. + exe.launch(); + + // Rough measure. + int warmup_iter = 3; + float target_msec = 2000; + exe.run(warmup_iter); + float warmup_msec = exe.stop(); + + if (warmup_msec > target_msec) { + // Warm-up was long enough. + result.msec_per_iter = warmup_msec / warmup_iter; + } else { + int iter = int(target_msec / warmup_msec); + exe.launch(); + exe.run(iter); + float msec = exe.stop(); + result.msec_per_iter = msec / iter; } - LOG(ark::INFO, op_name, ":", n, 'x', m, ",", type_name, ",bs=", bs, - ",overwrite=", overwrite, setprecision(4), " mse ", p.first, - " max_err ", max_err * 100, "%"); + exe.stop(); - free(res); - free(gt); + // Free resources + for (auto ptr : input_data) { + ::free(ptr); + } + for (auto ptr : res) { + ::free(ptr); + } + for (auto ptr : gt) { + ::free(ptr); + } - UNITTEST_EQ(max_err, 0.0); + return result; +} + +OpsTestResult op_test_8(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method, bool print_on_error) +{ + return op_test(test_name_prefix, model, inputs, outputs, baseline, + init_method, 8, print_on_error); } -void test_bcast_fp32(string op_name, ark::DimType bs, ark::DimType n, - ark::DimType m, bool overwrite) +OpsTestResult op_test_16(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method, bool print_on_error) { - test_bcast_internal(op_name, ark::FP32, bs, n, m, overwrite); + return op_test(test_name_prefix, model, inputs, outputs, baseline, + init_method, 16, print_on_error); } -void test_bcast_fp16(string op_name, ark::DimType bs, ark::DimType n, - ark::DimType m, bool overwrite) +OpsTestResult op_test_32(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method, bool print_on_error) { - test_bcast_internal(op_name, ark::FP16, bs, n, m, overwrite); + return op_test(test_name_prefix, model, inputs, outputs, baseline, + init_method, 32, print_on_error); } + +void op_test_log(const OpsTestResult &result) +{ + LOG(INFO, result); +} + +#define CUDA_CHECK(status) \ + do { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::ostringstream oss; \ + oss << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__; \ + throw std::runtime_error(oss.str()); \ + } \ + } while (0); + +OpsTestGpuMem::OpsTestGpuMem(size_t size) : size_(size) +{ + CUDA_CHECK(cudaMalloc(&this->gpu_ptr_, size)); +} + +OpsTestGpuMem::~OpsTestGpuMem() +{ + cudaFree(this->gpu_ptr_); +} + +void *OpsTestGpuMem::get() const +{ + return this->gpu_ptr_; +} + +size_t OpsTestGpuMem::size() const +{ + return this->size_; +} + +OpsTestGpuMem to_gpu(const void *host_ptr, size_t size) +{ + OpsTestGpuMem gpu_mem(size); + CUDA_CHECK( + cudaMemcpy(gpu_mem.get(), host_ptr, size, cudaMemcpyHostToDevice)); + return gpu_mem; +} + +void *from_gpu(const OpsTestGpuMem &test_gpu_mem, void *host_ptr) +{ + if (host_ptr == nullptr) { + host_ptr = ::malloc(test_gpu_mem.size()); + } + CUDA_CHECK(cudaMemcpy(host_ptr, test_gpu_mem.get(), test_gpu_mem.size(), + cudaMemcpyDeviceToHost)); + return host_ptr; +} + +void sync_gpu() +{ + CUDA_CHECK(cudaDeviceSynchronize()); +} + +} // namespace ark diff --git a/ark/ops/ops_test_common.h b/ark/ops/ops_test_common.h index fa99e4cf4..718c97f9f 100644 --- a/ark/ops/ops_test_common.h +++ b/ark/ops/ops_test_common.h @@ -5,13 +5,119 @@ #define ARK_OPS_TEST_COMMON_H_ #include "include/ark.h" +#include "include/ark_utils.h" +#include "unittest/unittest_utils.h" +#include +#include #include -// TODO: deprecate this -void test_bcast_fp32(std::string op_name, ark::DimType bs, ark::DimType n, - ark::DimType m, bool overwrite = false); -// TODO: deprecate this -void test_bcast_fp16(std::string op_name, ark::DimType bs, ark::DimType n, - ark::DimType m, bool overwrite = false); +namespace ark { + +struct TensorCompareResult +{ + float mse; + float max_diff; + float max_error_rate; +}; + +TensorCompareResult tensor_compare(half_t *ground_truth, half_t *res, + Dims shape, bool print = false); + +/// Return mean squared error and max error rate between two float tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(float *ground_truth, float *res, Dims shape, + bool print = false); + +/// Return mean squared error and max error rate between two int tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(int *ground_truth, int *res, Dims shape, + bool print = false); + +/// Return mean squared error and max error rate between two byte tensors. +/// @param ground_truth ground truth data array. +/// @param res input data array to compare with the ground truth. +/// @param shape shape of the tensor. +/// @param print whether to print wrong values. +/// @return a pair of mean squared error and max error rate. +TensorCompareResult tensor_compare(uint8_t *ground_truth, uint8_t *res, + Dims shape, bool print = false); + +struct OpsTestResult +{ + std::string test_name; + int num_warps_per_sm; + float msec_per_iter; + std::vector mse; + std::vector max_diff; + std::vector max_err_rate; +}; + +std::ostream &operator<<(std::ostream &os, const OpsTestResult &result); + +class OpsTestGpuMem +{ + public: + OpsTestGpuMem(size_t size); + ~OpsTestGpuMem(); + void *get() const; + size_t size() const; + + private: + size_t size_; + void *gpu_ptr_; +}; + +/// A function that takes input arrays and returns the ground truth output +/// arrays +using OpsTestBaseline = std::function &outputs, const std::vector &output_tensors, + const std::vector &inputs, + const std::vector &input_tensors)>; + +OpsTestResult op_test(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method = "random", + bool print_on_error = false, int num_warps_per_sm = 8); + +OpsTestResult op_test_8(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method = "random", + bool print_on_error = false); + +OpsTestResult op_test_16(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method = "random", + bool print_on_error = false); + +OpsTestResult op_test_32(const std::string &test_name_prefix, Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::string &init_method = "random", + bool print_on_error = false); + +void op_test_log(const OpsTestResult &result); + +OpsTestGpuMem to_gpu(const void *host_ptr, size_t size); + +void *from_gpu(const OpsTestGpuMem &test_gpu_mem, void *host_ptr = nullptr); + +void sync_gpu(); + +} // namespace ark #endif // ARK_OPS_TEST_COMMON_H_ diff --git a/ark/random.cc b/ark/random.cc index ba5f49e98..fb960dc67 100644 --- a/ark/random.cc +++ b/ark/random.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include #include #include #include @@ -27,4 +28,19 @@ int rand() return ::rand(); } +// Generate a random alpha-numeric string. +std::string rand_anum(size_t len) +{ + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = sizeof(charset) - 1; + return charset[rand() % max_index]; + }; + std::string str(len, 0); + std::generate_n(str.begin(), len, randchar); + return str; +} + } // namespace ark diff --git a/ark/random.h b/ark/random.h new file mode 100644 index 000000000..78c5f2a2a --- /dev/null +++ b/ark/random.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_RANDOM_H_ +#define ARK_RANDOM_H_ + +namespace ark { + +// Generate a random alpha-numeric string. +std::string rand_anum(size_t len); + +} // namespace ark + +#endif // ARK_RANDOM_H_ diff --git a/ark/sched/sched_test.cc b/ark/sched/sched_test.cc index 2fad4a17a..d336481f8 100644 --- a/ark/sched/sched_test.cc +++ b/ark/sched/sched_test.cc @@ -5,6 +5,7 @@ #include "include/ark.h" #include "include/ark_utils.h" #include "logging.h" +#include "ops/ops_test_common.h" #include "sched/sched.h" #include "unittest/unittest_utils.h" @@ -354,12 +355,11 @@ ark::unittest::State test_sched_comp_baseline() ark::unittest::wait_all_processes(); // TODO: the output data are set on different processes, we need to copy // run the test on the same process - auto p = - ark::utils::cmp_matrix((ark::half_t *)output_data1, - (ark::half_t *)output_data2, channel, units); + auto comp = tensor_compare(output_data1, output_data2, + ark::Dims(batch_size, units, units)); LOG(ark::INFO, " scheduler compare test: ", " total_bytes: ", bytes, - " iter: ", 1, setprecision(4), " mse: ", p.first, - " max_err: ", p.second * 100, "%"); + " iter: ", 1, setprecision(4), " mse: ", comp.mse, + " max_err: ", comp.max_error_rate * 100, "%"); return unittest::SUCCESS; } diff --git a/ark/tensor.cc b/ark/tensor.cc index f68bc3e99..965310bcd 100644 --- a/ark/tensor.cc +++ b/ark/tensor.cc @@ -252,7 +252,7 @@ void Tensor::write(const void *buf) assert(done == bytes); } -void Tensor::read(void *buf) +void *Tensor::read(void *buf) { GpuBuf *gbuf = static_cast(this->buf->buf); if (gbuf == nullptr) { @@ -260,10 +260,16 @@ void Tensor::read(void *buf) } size_t bytes = this->shape_bytes(); int ndims = this->ndims(); + if (buf == nullptr) { + buf = ::malloc(bytes); + if (buf == nullptr) { + LOG(ERROR, "failed to allocate host buffer"); + } + } char *ptr = (char *)buf; if (ndims == 1) { gpu_memcpy(ptr, gbuf->ref(this->offset_bytes(0)), bytes); - return; + return ptr; } size_t done = 0; size_t rem = bytes; @@ -307,6 +313,24 @@ void Tensor::read(void *buf) } assert(rem == 0); assert(done == bytes); + return buf; +} + +void *Tensor::read_raw(void *buf) +{ + GpuBuf *gbuf = static_cast(this->buf->buf); + if (gbuf == nullptr) { + LOG(ERROR, "failed to get GPU buffer for tensor ", this->id); + } + size_t bytes = this->ldims_bytes(); + if (buf == nullptr) { + buf = ::malloc(bytes); + if (buf == nullptr) { + LOG(ERROR, "failed to allocate host buffer"); + } + } + gpu_memcpy(buf, gbuf->ref(this->offset_bytes(0)), bytes); + return buf; } void Tensor::clear() diff --git a/ark/utils.cc b/ark/utils.cc index abfacb1e3..d111c9a2e 100644 --- a/ark/utils.cc +++ b/ark/utils.cc @@ -123,6 +123,14 @@ unique_ptr rand_floats(size_t num, float max_val) return rand_array(num, max_val); } +/// Return a random bytes array. +/// @param num Number of elements +/// @return std::unique_ptr +unique_ptr rand_bytes(size_t num) +{ + return rand_array(num, 255); +} + /// Return an array of values starting from `begin` with difference `diff`. /// @tparam T Type of the array /// @param num Number of elements @@ -160,255 +168,6 @@ unique_ptr range_floats(size_t num, float begin, float diff) return range_array(num, begin, diff); } -/// Calculate the error rate between two values. -/// @tparam T Type of the values -/// @param a First value -/// @param b Second value -/// @return The error rate -template float error_rate(T a, T b) -{ - T diff = abs(a - b); - if (diff < numeric_limits::min()) { - return 0; - } - diff -= numeric_limits::epsilon(); - T half_eps = numeric_limits::epsilon() * T(0.5); - if (a > b) { - a -= half_eps; - b += half_eps; - } else { - a += half_eps; - b -= half_eps; - } - return (float)diff / max(abs((float)a), abs((float)b)); -} - -/// Calculate the error rate between two @ref half_t values. -/// @param a First value -/// @param b Second value -/// @return The error rate -float error_rate(half_t a, half_t b) -{ - return error_rate(a, b); -} - -/// Calculate the error rate between two floats. -/// @param a First value -/// @param b Second value -/// @return The error rate -float error_rate(float a, float b) -{ - return error_rate(a, b); -} - -/// Return mean squared error and max error rate between two matrices. -template -pair cmp_matrix(T *ground_truth, T *res, unsigned int m, - unsigned int n, unsigned int bs, unsigned int lm, - unsigned int ln, bool print) -{ - // TODO: deprecate this function. - - if (lm == 0) { - lm = m; - } - if (ln == 0) { - ln = n; - } - size_t num = (size_t)lm * (size_t)ln; - - const float thres_err = 0.01; - - float l2_loss = 0; - float max_err = 0; - // int cnt_flip = 0; - // float max_err_gv; - // float max_err_rv; - for (unsigned int bidx = 0; bidx < bs; ++bidx) { - for (unsigned int nidx = 0; nidx < n; ++nidx) { - for (unsigned int midx = 0; midx < m; ++midx) { - unsigned int idx = midx + nidx * lm + bidx * lm * ln; - T gv = ground_truth[idx]; - T rv = res[idx]; - float diff = (float)(gv - rv); - l2_loss += diff * diff; - float err = error_rate(gv, rv); - // if ((err > thres_err) && (error_rate(gv, -rv) < - // thres_err) && - // (((float)gv * (float)rv) < 0)) { - // cnt_flip++; - // cout << (float)gv << "," << (float)rv << endl; - // cout << hex << gv.storage << "," << rv.storage << dec << - // endl; - // } - if (err > max_err) { - max_err = err; - // max_err_gv = (float)gv; - // max_err_rv = (float)rv; - } - } - } - } - if (print) { - unsigned int x = 0; - unsigned int cc = 0; - cout << setprecision(4); - for (unsigned int bidx = 0; bidx < bs; ++bidx) { - for (unsigned int nidx = 0; nidx < n; ++nidx) { - for (unsigned int midx = 0; midx < m; ++midx) { - unsigned int idx = midx + nidx * lm + bidx * lm * ln; - T exp = ground_truth[idx]; - T act = res[idx]; - if (error_rate(exp, act) < thres_err) { - cout << (float)act << ','; - } else { - cout << "\033[0;31m" << (float)act << "\033[0m," - << "\033[0;32m" << (float)exp << "\033[0m,"; - } - if (++cc == m) { - cout << '[' << x << ']' << endl; - cc = 0; - x++; - } - } - } - } - } - // cout << max_err_gv << endl; - // cout << max_err_rv << endl; - // cout << cnt_flip << endl; - return {l2_loss / num, max_err}; -} - -// -pair cmp_matrix(half_t *ground_truth, half_t *res, unsigned int m, - unsigned int n, unsigned int bs, unsigned int lm, - unsigned int ln, bool print) -{ - // TODO: deprecate this function. - - return cmp_matrix(ground_truth, res, m, n, bs, lm, ln, print); -} - -// -pair cmp_matrix(float *ground_truth, float *res, unsigned int m, - unsigned int n, unsigned int bs, unsigned int lm, - unsigned int ln, bool print) -{ - // TODO: deprecate this function. - - return cmp_matrix(ground_truth, res, m, n, bs, lm, ln, print); -} - -// -template -void print_matrix(T *val, unsigned int m, unsigned int n, unsigned int bs, - unsigned int lm, unsigned int ln) -{ - // TODO: deprecate this function. - - unsigned int x = 0; - unsigned int cc = 0; - cout << setprecision(4); - for (unsigned int bidx = 0; bidx < bs; ++bidx) { - for (unsigned int nidx = 0; nidx < n; ++nidx) { - for (unsigned int midx = 0; midx < m; ++midx) { - unsigned int idx = midx + nidx * lm + bidx * lm * ln; - T v = val[idx]; - cout << (float)v << ','; - if (++cc == m) { - cout << '[' << x << ']' << endl; - cc = 0; - x++; - } - } - } - } -} - -void print_matrix(half_t *val, unsigned int m, unsigned int n, unsigned int bs, - unsigned int lm, unsigned int ln) -{ - // TODO: deprecate this function. - - print_matrix(val, m, n, bs, lm, ln); -} - -void print_matrix(float *val, unsigned int m, unsigned int n, unsigned int bs, - unsigned int lm, unsigned int ln) -{ - // TODO: deprecate this function. - - print_matrix(val, m, n, bs, lm, ln); -} - -/// Return mean squared error and max error rate between two tensors. -/// @tparam T data type of the tensors. -/// @param ground_truth ground truth data array. -/// @param res input data array to compare with the ground truth. -/// @param shape shape of the tensor. -/// @param print whether to print wrong values. -/// @return a pair of mean squared error and max error rate. -template -std::pair tensor_compare(T *ground_truth, T *res, Dims shape, - bool print = false) -{ - DimType nelem = shape.size(); - int ndims = shape.ndims(); - float l2_loss = 0; - float max_err = 0; - for (DimType i = 0; i < nelem; ++i) { - float diff = (float)(ground_truth[i] - res[i]); - l2_loss += diff * diff; - - float err = error_rate(ground_truth[i], res[i]); - if (err > 0.) { - if (print) { - Dims idx; - for (int j = 0; j < ndims; ++j) { - DimType vol = 1; - for (int k = j + 1; k < ndims; ++k) { - vol *= shape[k]; - } - idx[j] = (i / vol) % shape[j]; - } - std::cout << idx << " expected " << ground_truth[i] - << ", actually " << res[i] << " (err: " << err << ")" - << std::endl; - } - if (err > max_err) { - max_err = err; - } - } - } - return {l2_loss / nelem, max_err}; -} - -/// Return mean squared error and max error rate between two @ref half_t -/// tensors. -/// @param ground_truth ground truth data array. -/// @param res input data array to compare with the ground truth. -/// @param shape shape of the tensor. -/// @param print whether to print wrong values. -/// @return a pair of mean squared error and max error rate. -std::pair tensor_compare(half_t *ground_truth, half_t *res, - Dims shape, bool print = false) -{ - return tensor_compare(ground_truth, res, shape, print); -} - -/// Return mean squared error and max error rate between two float tensors. -/// @param ground_truth ground truth data array. -/// @param res input data array to compare with the ground truth. -/// @param shape shape of the tensor. -/// @param print whether to print wrong values. -/// @return a pair of mean squared error and max error rate. -std::pair tensor_compare(float *ground_truth, float *res, - Dims shape, bool print = false) -{ - return tensor_compare(ground_truth, res, shape, print); -} - /// Spawn a process that runs `func`. /// @param func function to run in the spawned process. /// @return PID of the spawned process.