diff --git a/distance.hpp b/distance.hpp index 54f14c6..b76c04a 100644 --- a/distance.hpp +++ b/distance.hpp @@ -66,10 +66,10 @@ static avs::matf32_t avx512_subtract_batch(avs::vecf32_t const &query, int32_t const N = data.size(); int32_t const dim = data[0].size(); avs::matf32_t result(N, avs::vecf32_t(dim, 0.0f)); - for (int i = 0; i < N; i++) { + for (int32_t i = 0; i < N; i++) { float PORTABLE_ALIGN64 tmp[dim]; avx512_subtract(query.data(), data[i].data(), tmp); - for (int k = 0; k < dim; k++) { + for (int32_t k = 0; k < dim; k++) { result[i][k] = tmp[k]; } } @@ -104,7 +104,7 @@ static avs::vecf32_t amx_matmul(const int32_t &r, const int32_t &c, auto c_mem = dnnl::memory(pd.dst_desc(), engine); auto prim = dnnl::matmul(pd); - std::unordered_map args; + std::unordered_map args; args.insert({DNNL_ARG_SRC, a_mem}); args.insert({DNNL_ARG_WEIGHTS, b_mem}); args.insert({DNNL_ARG_DST, c_mem}); @@ -112,11 +112,11 @@ static avs::vecf32_t amx_matmul(const int32_t &r, const int32_t &c, stream.wait(); read_from_dnnl_memory(dst.data(), c_mem); - avs::vecf32_t res(r, 0.0f); - for (int i = 0; i < r; i++) { - res[i] = dst[i * r + i]; + avs::vecf32_t result(r, 0.0f); + for (int32_t i = 0; i < r; i++) { + result[i] = dst[i * r + i]; } - return res; + return result; } static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc, @@ -149,7 +149,7 @@ static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc, auto dst_mem = dnnl::memory(pd.dst_desc(), engine); auto prim = dnnl::inner_product_forward(pd); - std::unordered_map args; + std::unordered_map args; args.insert({DNNL_ARG_SRC, s_mem}); args.insert({DNNL_ARG_WEIGHTS, w_mem}); args.insert({DNNL_ARG_DST, dst_mem}); @@ -158,13 +158,13 @@ static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc, avs::vecf32_t dst(n * oc, 0.0f); read_from_dnnl_memory(dst.data(), dst_mem); - avs::matf32_t res(n, avs::vecf32_t(oc, 0.0f)); - for (int i = 0; i < n; i++) { - for (int j = 0; j < oc; j++) { - res[i][j] = dst[i * oc + j]; + avs::matf32_t result(n, avs::vecf32_t(oc, 0.0f)); + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < oc; j++) { + result[i][j] = dst[i * oc + j]; } } - return res; + return result; } static avs::matf32_t ip_distance_amx(avs::matf32_t const &queries, @@ -176,13 +176,13 @@ static avs::matf32_t ip_distance_amx(avs::matf32_t const &queries, int32_t const ic = queries[0].size(); avs::vecf32_t queries_1d(n * ic); avs::vecf32_t batch_1d(oc * ic); - for (int i = 0; i < n; i++) { - for (int j = 0; j < ic; j++) { + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < ic; j++) { queries_1d[i * ic + j] = queries[i][j]; } } - for (int i = 0; i < oc; i++) { - for (int j = 0; j < ic; j++) { + for (int32_t i = 0; i < oc; i++) { + for (int32_t j = 0; j < ic; j++) { batch_1d[i * ic + j] = batch[i][j]; } } @@ -197,20 +197,20 @@ static avs::vecf32_t l2_distance_amx(avs::vecf32_t const &query, int32_t const dim = batch[0].size(); avs::matf32_t dis_2d = avx512_subtract_batch(query, batch); avs::matf32_t dis_2d_t(dis_2d[0].size(), avs::vecf32_t(dis_2d.size(), 0.0f)); - for (int i = 0; i < dis_2d_t.size(); i++) { - for (int j = 0; j < dis_2d_t[0].size(); j++) { + for (int32_t i = 0; i < dis_2d_t.size(); i++) { + for (int32_t j = 0; j < dis_2d_t[0].size(); j++) { dis_2d_t[i][j] = dis_2d[j][i]; } } avs::vecf32_t dis_1d(batch_size * dim); - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < dim; j++) { + for (int32_t i = 0; i < batch_size; i++) { + for (int32_t j = 0; j < dim; j++) { dis_1d[i * dim + j] = dis_2d[i][j]; } } avs::vecf32_t dis_1d_t(batch_size * dim); - for (int i = 0; i < dim; i++) { - for (int j = 0; j < batch_size; j++) { + for (int32_t i = 0; i < dim; i++) { + for (int32_t j = 0; j < batch_size; j++) { dis_1d_t[i * batch_size + j] = dis_2d_t[i][j]; } } @@ -220,25 +220,25 @@ static avs::vecf32_t l2_distance_amx(avs::vecf32_t const &query, static float L2Sqr(void const *vec1, void const *vec2, int32_t const &dim) { float *v1 = (float *)vec1; float *v2 = (float *)vec2; - float res = 0; + float result = 0; for (int32_t i = 0; i < dim; i++) { float t = *v1 - *v2; v1++; v2++; - res += t * t; + result += t * t; } - return (res); + return (result); } static float InnerProduct(void const *vec1, void const *vec2, int32_t const &dim) { float *v1 = (float *)vec1; float *v2 = (float *)vec2; - float res = 0; + float result = 0; for (int32_t i = 0; i < dim; i++) { - res += ((float *)v1)[i] * ((float *)v2)[i]; + result += ((float *)v1)[i] * ((float *)v2)[i]; } - return res; + return result; } static avs::vecf32_t l2_distance_vanilla(avs::vecf32_t const &query, @@ -246,12 +246,12 @@ static avs::vecf32_t l2_distance_vanilla(avs::vecf32_t const &query, dnnl::engine &engine, dnnl::stream &stream) { int32_t const dim = batch[0].size(); - avs::vecf32_t res(batch.size()); - for (int i = 0; i < batch.size(); i++) { + avs::vecf32_t result(batch.size()); + for (int32_t i = 0; i < batch.size(); i++) { auto d = L2Sqr(query.data(), batch[i].data(), dim); - res[i] = d; + result[i] = d; } - return res; + return result; } static avs::vecf32_t ip_distance_vanilla(avs::vecf32_t const &query, @@ -259,12 +259,12 @@ static avs::vecf32_t ip_distance_vanilla(avs::vecf32_t const &query, dnnl::engine &engine, dnnl::stream &stream) { int32_t const dim = batch[0].size(); - avs::vecf32_t res(batch.size()); - for (int i = 0; i < batch.size(); i++) { + avs::vecf32_t result(batch.size()); + for (int32_t i = 0; i < batch.size(); i++) { auto d = InnerProduct(query.data(), batch[i].data(), dim); - res[i] = d; + result[i] = d; } - return res; + return result; } } // namespace avs