Skip to content

Commit

Permalink
Format simd APIs for better readability (#1088)
Browse files Browse the repository at this point in the history
Signed-off-by: CaiYudong <[email protected]>
  • Loading branch information
cydrain authored Feb 20, 2025
1 parent 2a43227 commit 5587dda
Show file tree
Hide file tree
Showing 12 changed files with 3,607 additions and 3,493 deletions.
1,071 changes: 540 additions & 531 deletions src/simd/distances_avx.cc

Large diffs are not rendered by default.

105 changes: 58 additions & 47 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,10 @@ namespace faiss {
float
fvec_L2sqr_avx(const float* x, const float* y, size_t d);

float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// inner product
float
fvec_inner_product_avx(const float* x, const float* y, size_t d);

float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// L1 distance
float
fvec_L1_avx(const float* x, const float* y, size_t d);
Expand All @@ -60,59 +42,88 @@ void
fvec_inner_product_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
float
fvec_norm_L2sqr_avx(const float* x, size_t d);

void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);

void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);

///////////////////////////////////////////////////////////////////////////////
// for hnsw sq, obsolete

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

///////////////////////////////////////////////////////////////////////////////
// fp16

float
fvec_norm_L2sqr_avx(const float* x, size_t d);
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);
void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

///////////////////////////////////////////////////////////////////////////////
// bf16

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d);

void
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);
void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

///////////////////////////////////////////////////////////////////////////////
// for cardinal

float
fvec_inner_product_bf16_patch_avx(const float* x, const float* y, size_t d);

float
fvec_L2sqr_bf16_patch_avx(const float* x, const float* y, size_t d);

void
fvec_inner_product_batch_4_bf16_patch_avx(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fvec_L2sqr_batch_4_bf16_patch_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

} // namespace faiss

Expand Down
Loading

0 comments on commit 5587dda

Please sign in to comment.