Skip to content

Commit

Permalink
Improve: cBLAS Bilinear Form benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Nov 24, 2024
1 parent 9dc2a69 commit 1e5b9d7
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions scripts/bench.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,41 @@ void vdot_f64c_blas(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_si
cblas_zdotc_sub((int)n, (simsimd_f64_t const *)a, 1, (simsimd_f64_t const *)b, 1, result);
}

void bilinear_f32_blas(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n,
simsimd_distance_t *result) {
simsimd_f32_t intermediate[n];
simsimd_f32_t alpha = 1.0f, beta = 0.0f;
cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, alpha, c, (int)n, b, 1, beta, intermediate, 1);
*result = cblas_sdot((int)n, a, 1, intermediate, 1);
}

void bilinear_f64_blas(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n,
simsimd_distance_t *result) {
simsimd_f64_t intermediate[n];
simsimd_f64_t alpha = 1.0, beta = 0.0;
cblas_dgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, alpha, c, n, b, 1, beta, intermediate, 1);
*result = cblas_ddot((int)n, a, 1, intermediate, 1);
}

void bilinear_f32c_blas(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_f32c_t const *c, simsimd_size_t n,
simsimd_distance_t *results) {
simsimd_f32c_t intermediate[n];
simsimd_f32c_t alpha = {1.0f, 0.0f}, beta = {0.0f, 0.0f};
cblas_cgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, &alpha, c, n, b, 1, &beta, intermediate, 1);
simsimd_f32_t f32_result[2] = {0, 0};
cblas_cdotu_sub((int)n, (simsimd_f32_t const *)a, 1, (simsimd_f32_t const *)intermediate, 1, f32_result);
results[0] = f32_result[0];
results[1] = f32_result[1];
}

void bilinear_f64c_blas(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_f64c_t const *c, simsimd_size_t n,
simsimd_distance_t *results) {
simsimd_f64c_t intermediate[n];
simsimd_f64c_t alpha = {1.0, 0.0}, beta = {0.0, 0.0};
cblas_zgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, &alpha, c, n, b, 1, &beta, intermediate, 1);
cblas_zdotu_sub((int)n, (simsimd_f64_t const *)a, 1, (simsimd_f64_t const *)intermediate, 1, results);
}

#endif

int main(int argc, char **argv) {
Expand Down Expand Up @@ -738,6 +773,11 @@ int main(int argc, char **argv) {
dense_<f32c_k>("vdot_f32c_blas", vdot_f32c_blas, simsimd_vdot_f32c_accurate);
dense_<f64c_k>("vdot_f64c_blas", vdot_f64c_blas, simsimd_vdot_f64c_serial);

curved_<f64_k>("bilinear_f64_blas", bilinear_f64_blas, simsimd_bilinear_f64_serial);
curved_<f64c_k>("bilinear_f64c_blas", bilinear_f64c_blas, simsimd_bilinear_f64c_serial);
curved_<f32_k>("bilinear_f32_blas", bilinear_f32_blas, simsimd_bilinear_f32_accurate);
curved_<f32c_k>("bilinear_f32c_blas", bilinear_f32c_blas, simsimd_bilinear_f32c_accurate);

#endif

#if SIMSIMD_TARGET_NEON
Expand Down Expand Up @@ -995,6 +1035,8 @@ int main(int argc, char **argv) {
fma_<bf16_k>("fma_bf16_skylake", simsimd_fma_bf16_skylake, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate);
fma_<bf16_k>("wsum_bf16_skylake", simsimd_wsum_bf16_skylake, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate);

curved_<f64_k>("bilinear_f64_skylake", simsimd_bilinear_f64_skylake, simsimd_bilinear_f64_serial);
curved_<f64c_k>("bilinear_f64c_skylake", simsimd_bilinear_f64c_skylake, simsimd_bilinear_f64c_serial);
#endif

sparse_<u16_k>("intersect_u16_serial", simsimd_intersect_u16_serial, simsimd_intersect_u16_accurate);
Expand All @@ -1003,12 +1045,16 @@ int main(int argc, char **argv) {
sparse_<u32_k>("intersect_u32_accurate", simsimd_intersect_u32_accurate, simsimd_intersect_u32_accurate);

curved_<f64_k>("bilinear_f64_serial", simsimd_bilinear_f64_serial, simsimd_bilinear_f64_serial);
curved_<f64c_k>("bilinear_f64c_serial", simsimd_bilinear_f64c_serial, simsimd_bilinear_f64c_serial);
curved_<f64_k>("mahalanobis_f64_serial", simsimd_mahalanobis_f64_serial, simsimd_mahalanobis_f64_serial);
curved_<f32_k>("bilinear_f32_serial", simsimd_bilinear_f32_serial, simsimd_bilinear_f32_accurate);
curved_<f32c_k>("bilinear_f32c_serial", simsimd_bilinear_f32c_serial, simsimd_bilinear_f32c_accurate);
curved_<f32_k>("mahalanobis_f32_serial", simsimd_mahalanobis_f32_serial, simsimd_mahalanobis_f32_accurate);
curved_<f16_k>("bilinear_f16_serial", simsimd_bilinear_f16_serial, simsimd_bilinear_f16_accurate);
curved_<f16c_k>("bilinear_f16c_serial", simsimd_bilinear_f16c_serial, simsimd_bilinear_f16c_accurate);
curved_<f16_k>("mahalanobis_f16_serial", simsimd_mahalanobis_f16_serial, simsimd_mahalanobis_f16_accurate);
curved_<bf16_k>("bilinear_bf16_serial", simsimd_bilinear_bf16_serial, simsimd_bilinear_bf16_accurate);
curved_<bf16c_k>("bilinear_bf16c_serial", simsimd_bilinear_bf16c_serial, simsimd_bilinear_bf16c_accurate);
curved_<bf16_k>("mahalanobis_bf16_serial", simsimd_mahalanobis_bf16_serial, simsimd_mahalanobis_bf16_accurate);

dense_<bf16_k>("dot_bf16_serial", simsimd_dot_bf16_serial, simsimd_dot_bf16_accurate);
Expand Down

0 comments on commit 1e5b9d7

Please sign in to comment.