Skip to content

Commit

Permalink
use int32 everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
JayjeetAtGithub committed Nov 15, 2024
1 parent c34c72e commit ec6bb80
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ static avs::matf32_t avx512_subtract_batch(avs::vecf32_t const &query,
int32_t const N = data.size();
int32_t const dim = data[0].size();
avs::matf32_t result(N, avs::vecf32_t(dim, 0.0f));
for (int i = 0; i < N; i++) {
for (int32_t i = 0; i < N; i++) {
float PORTABLE_ALIGN64 tmp[dim];
avx512_subtract(query.data(), data[i].data(), tmp);
for (int k = 0; k < dim; k++) {
for (int32_t k = 0; k < dim; k++) {
result[i][k] = tmp[k];
}
}
Expand Down Expand Up @@ -104,19 +104,19 @@ static avs::vecf32_t amx_matmul(const int32_t &r, const int32_t &c,
auto c_mem = dnnl::memory(pd.dst_desc(), engine);

auto prim = dnnl::matmul(pd);
std::unordered_map<int, dnnl::memory> args;
std::unordered_map<int32_t, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, a_mem});
args.insert({DNNL_ARG_WEIGHTS, b_mem});
args.insert({DNNL_ARG_DST, c_mem});
prim.execute(stream, args);
stream.wait();

read_from_dnnl_memory(dst.data(), c_mem);
avs::vecf32_t res(r, 0.0f);
for (int i = 0; i < r; i++) {
res[i] = dst[i * r + i];
avs::vecf32_t result(r, 0.0f);
for (int32_t i = 0; i < r; i++) {
result[i] = dst[i * r + i];
}
return res;
return result;
}

static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc,
Expand Down Expand Up @@ -149,7 +149,7 @@ static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc,
auto dst_mem = dnnl::memory(pd.dst_desc(), engine);

auto prim = dnnl::inner_product_forward(pd);
std::unordered_map<int, dnnl::memory> args;
std::unordered_map<int32_t, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, s_mem});
args.insert({DNNL_ARG_WEIGHTS, w_mem});
args.insert({DNNL_ARG_DST, dst_mem});
Expand All @@ -158,13 +158,13 @@ static matf32_t amx_inner_product(int32_t const &n, int32_t const &oc,

avs::vecf32_t dst(n * oc, 0.0f);
read_from_dnnl_memory(dst.data(), dst_mem);
avs::matf32_t res(n, avs::vecf32_t(oc, 0.0f));
for (int i = 0; i < n; i++) {
for (int j = 0; j < oc; j++) {
res[i][j] = dst[i * oc + j];
avs::matf32_t result(n, avs::vecf32_t(oc, 0.0f));
for (int32_t i = 0; i < n; i++) {
for (int32_t j = 0; j < oc; j++) {
result[i][j] = dst[i * oc + j];
}
}
return res;
return result;
}

static avs::matf32_t ip_distance_amx(avs::matf32_t const &queries,
Expand All @@ -176,13 +176,13 @@ static avs::matf32_t ip_distance_amx(avs::matf32_t const &queries,
int32_t const ic = queries[0].size();
avs::vecf32_t queries_1d(n * ic);
avs::vecf32_t batch_1d(oc * ic);
for (int i = 0; i < n; i++) {
for (int j = 0; j < ic; j++) {
for (int32_t i = 0; i < n; i++) {
for (int32_t j = 0; j < ic; j++) {
queries_1d[i * ic + j] = queries[i][j];
}
}
for (int i = 0; i < oc; i++) {
for (int j = 0; j < ic; j++) {
for (int32_t i = 0; i < oc; i++) {
for (int32_t j = 0; j < ic; j++) {
batch_1d[i * ic + j] = batch[i][j];
}
}
Expand All @@ -197,20 +197,20 @@ static avs::vecf32_t l2_distance_amx(avs::vecf32_t const &query,
int32_t const dim = batch[0].size();
avs::matf32_t dis_2d = avx512_subtract_batch(query, batch);
avs::matf32_t dis_2d_t(dis_2d[0].size(), avs::vecf32_t(dis_2d.size(), 0.0f));
for (int i = 0; i < dis_2d_t.size(); i++) {
for (int j = 0; j < dis_2d_t[0].size(); j++) {
for (int32_t i = 0; i < dis_2d_t.size(); i++) {
for (int32_t j = 0; j < dis_2d_t[0].size(); j++) {
dis_2d_t[i][j] = dis_2d[j][i];
}
}
avs::vecf32_t dis_1d(batch_size * dim);
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < dim; j++) {
for (int32_t i = 0; i < batch_size; i++) {
for (int32_t j = 0; j < dim; j++) {
dis_1d[i * dim + j] = dis_2d[i][j];
}
}
avs::vecf32_t dis_1d_t(batch_size * dim);
for (int i = 0; i < dim; i++) {
for (int j = 0; j < batch_size; j++) {
for (int32_t i = 0; i < dim; i++) {
for (int32_t j = 0; j < batch_size; j++) {
dis_1d_t[i * batch_size + j] = dis_2d_t[i][j];
}
}
Expand All @@ -220,51 +220,51 @@ static avs::vecf32_t l2_distance_amx(avs::vecf32_t const &query,
// Computes the squared Euclidean (L2) distance between two float arrays.
//
// vec1, vec2: pointers that are read as `dim` consecutive floats each.
// dim:        number of elements read from each array.
// Returns the sum over i in [0, dim) of (vec1[i] - vec2[i])^2. The loop body
// never runs when dim <= 0, so the result is 0 in that case.
static float L2Sqr(void const *vec1, void const *vec2, int32_t const &dim) {
  // static_cast keeps the pointee const: both inputs are only read.
  float const *v1 = static_cast<float const *>(vec1);
  float const *v2 = static_cast<float const *>(vec2);
  float result = 0.0f;
  for (int32_t i = 0; i < dim; i++) {
    float const t = v1[i] - v2[i];
    result += t * t;
  }
  return result;
}

// Computes the inner (dot) product of two float arrays.
//
// vec1, vec2: pointers that are read as `dim` consecutive floats each.
// dim:        number of elements read from each array.
// Returns the sum over i in [0, dim) of vec1[i] * vec2[i]. The loop body
// never runs when dim <= 0, so the result is 0 in that case.
static float InnerProduct(void const *vec1, void const *vec2,
                          int32_t const &dim) {
  // static_cast keeps the pointee const: both inputs are only read.
  float const *v1 = static_cast<float const *>(vec1);
  float const *v2 = static_cast<float const *>(vec2);
  float result = 0.0f;
  for (int32_t i = 0; i < dim; i++) {
    result += v1[i] * v2[i];
  }
  return result;
}

// Computes the squared L2 distance from `query` to every row of `batch`
// with the scalar L2Sqr kernel.
//
// query:  the vector compared against each row; `dim` floats are read from it,
//         where `dim` is the length of the first row of `batch`.
// batch:  the rows to compare against.
// engine, stream: not referenced by this function.
// Returns one distance per row of `batch`, in row order; empty when `batch`
// has no rows.
static avs::vecf32_t l2_distance_vanilla(avs::vecf32_t const &query,
                                         avs::matf32_t const &batch,
                                         dnnl::engine &engine,
                                         dnnl::stream &stream) {
  // Single narrowing point, so the loop compares int32_t against int32_t.
  int32_t const count = static_cast<int32_t>(batch.size());
  avs::vecf32_t result(batch.size());
  // Guard before touching batch[0]: indexing an empty batch is out of bounds.
  if (count == 0) {
    return result;
  }
  int32_t const dim = batch[0].size();
  for (int32_t i = 0; i < count; i++) {
    result[i] = L2Sqr(query.data(), batch[i].data(), dim);
  }
  return result;
}

// Computes the inner product of `query` with every row of `batch` using the
// scalar InnerProduct kernel.
//
// query:  the vector multiplied against each row; `dim` floats are read from
//         it, where `dim` is the length of the first row of `batch`.
// batch:  the rows to multiply against.
// engine, stream: not referenced by this function.
// Returns one inner product per row of `batch`, in row order; empty when
// `batch` has no rows.
static avs::vecf32_t ip_distance_vanilla(avs::vecf32_t const &query,
                                         avs::matf32_t const &batch,
                                         dnnl::engine &engine,
                                         dnnl::stream &stream) {
  // Single narrowing point, so the loop compares int32_t against int32_t.
  int32_t const count = static_cast<int32_t>(batch.size());
  avs::vecf32_t result(batch.size());
  // Guard before touching batch[0]: indexing an empty batch is out of bounds.
  if (count == 0) {
    return result;
  }
  int32_t const dim = batch[0].size();
  for (int32_t i = 0; i < count; i++) {
    result[i] = InnerProduct(query.data(), batch[i].data(), dim);
  }
  return result;
}

} // namespace avs

0 comments on commit ec6bb80

Please sign in to comment.