Simplify the per-chunk loops in the CPU Adam / half-precision conversion kernels.

Each parallel_for body used an outer `j` loop (stepping by `span`, which was
always 1) that ran one full inner pass over [j, end) and then hit an
unconditional `break`. That visits exactly the indices start..end-1 once, so
this patch replaces the nest with a single loop over [start, end) and drops
the now-unused `span` local. The per-element arithmetic is unchanged.

NOTE(review): line breaks and indentation in this patch were reconstructed
(four spaces per level assumed); verify against the repository, or apply with
a whitespace-tolerant option, before relying on exact context matching.
NOTE(review): the reinterpret_cast calls in the context lines appear without a
template argument; presumably the `<...>` part was lost in extraction - confirm.

diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp
index 52575d6..ccae4db 100644
--- a/csrc/include/adam_cpu.hpp
+++ b/csrc/include/adam_cpu.hpp
@@ -141,16 +141,12 @@ void bf16_from_fp32_value_launcher(
     std::uintptr_t param_fp32,
     std::uintptr_t param_bf16
 ){
-    int span = 1;
     auto param_fp32_ptr = reinterpret_cast(param_fp32);
     auto param_bf16_ptr = reinterpret_cast(param_bf16);
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
-        for (int64_t j = start; j < end; j += span) {
-            for (int64_t i = j; i < end; i++) {
-                float p = param_fp32_ptr[i];
-                param_bf16_ptr[i] = bf16_from_fp32_value(p);
-            }
-            break; // must break here
+        for (int64_t i = start; i < end; i++) {
+            float p = param_fp32_ptr[i];
+            param_bf16_ptr[i] = bf16_from_fp32_value(p);
         }
     });
 }
@@ -160,16 +156,12 @@ void fp16_from_fp32_value_launcher(
     std::uintptr_t param_fp32,
     std::uintptr_t param_fp16
 ){
-    int span = 1;
     auto param_fp32_ptr = reinterpret_cast(param_fp32);
     auto param_fp16_ptr = reinterpret_cast(param_fp16);
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
-        for (int64_t j = start; j < end; j += span) {
-            for (int64_t i = j; i < end; i++) {
-                float p = param_fp32_ptr[i];
-                param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
-            }
-            break; // must break here
+        for (int64_t i = start; i < end; i++) {
+            float p = param_fp32_ptr[i];
+            param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
         }
     });
 }
@@ -195,33 +187,29 @@ void adam_cpu_0(
     float bias_correction1,
     float bias_correction2
 ){
-    int64_t span = 1;
     float sum_sq_delta = 0;
     float sum_delta = 0;
     std::mutex delta_mutex;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
         float sum_sq_delta_i = 0;
         float sum_delta_i = 0;
-        for (int64_t j = start; j < end; j += span) {
-            for (int64_t i = j; i < end; i++) {
-                float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
-                float m = m_fp32_ptr[i];
-                float v = v_fp32_ptr[i];
-                float p = param_fp32_ptr[i];
-                m = beta1 * m + (1 - beta1) * g;
-                v = beta2 * v + (1 - beta2) * g * g;
-                if (delta_info_ptr != NULL){
-                    float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
-                    sum_delta_i += delta;
-                    sum_sq_delta_i += delta * delta;
-                }
-                p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
-                param_fp32_ptr[i] = p;
-                param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
-                m_fp32_ptr[i] = m;
-                v_fp32_ptr[i] = v;
+        for (int64_t i = start; i < end; i++) {
+            float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+            float m = m_fp32_ptr[i];
+            float v = v_fp32_ptr[i];
+            float p = param_fp32_ptr[i];
+            m = beta1 * m + (1 - beta1) * g;
+            v = beta2 * v + (1 - beta2) * g * g;
+            if (delta_info_ptr != NULL){
+                float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                sum_delta_i += delta;
+                sum_sq_delta_i += delta * delta;
             }
-            break; // must break here
+            p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+            param_fp32_ptr[i] = p;
+            param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+            m_fp32_ptr[i] = m;
+            v_fp32_ptr[i] = v;
         }
         if (delta_info_ptr != NULL){
             delta_mutex.lock();
@@ -253,33 +241,29 @@ void adam_cpu_bf16_0(
     float bias_correction1,
     float bias_correction2
 ){
-    int64_t span = 1;
     float sum_sq_delta = 0;
     float sum_delta = 0;
     std::mutex delta_mutex;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
         float sum_sq_delta_i = 0;
         float sum_delta_i = 0;
-        for (int64_t j = start; j < end; j += span) {
-            for (int64_t i = j; i < end; i++) {
-                float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale;
-                float m = m_fp32_ptr[i];
-                float v = v_fp32_ptr[i];
-                float p = param_fp32_ptr[i];
-                m = beta1 * m + (1 - beta1) * g;
-                v = beta2 * v + (1 - beta2) * g * g;
-                if (delta_info_ptr != NULL){
-                    float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
-                    sum_delta_i += delta;
-                    sum_sq_delta_i += delta * delta;
-                }
-                p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
-                param_fp32_ptr[i] = p;
-                param_bf16_ptr[i] = bf16_from_fp32_value(p);
-                m_fp32_ptr[i] = m;
-                v_fp32_ptr[i] = v;
+        for (int64_t i = start; i < end; i++) {
+            float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale;
+            float m = m_fp32_ptr[i];
+            float v = v_fp32_ptr[i];
+            float p = param_fp32_ptr[i];
+            m = beta1 * m + (1 - beta1) * g;
+            v = beta2 * v + (1 - beta2) * g * g;
+            if (delta_info_ptr != NULL){
+                float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                sum_delta_i += delta;
+                sum_sq_delta_i += delta * delta;
             }
-            break; // must break here
+            p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+            param_fp32_ptr[i] = p;
+            param_bf16_ptr[i] = bf16_from_fp32_value(p);
+            m_fp32_ptr[i] = m;
+            v_fp32_ptr[i] = v;
         }
         if (delta_info_ptr != NULL){
             delta_mutex.lock();