From 770cfdd0412f352f336e0bd98d1f040f8da9b00c Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Thu, 26 Sep 2024 10:51:08 +0800 Subject: [PATCH] Update test_simd (#867) Signed-off-by: Cai Yudong --- src/simd/hook.cc | 29 ++-- tests/ut/test_simd.cc | 356 +++++++++++++++++++++--------------------- 2 files changed, 200 insertions(+), 185 deletions(-) diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 81e0c7e2d..cf4cd8707 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -185,11 +185,13 @@ fvec_hook(std::string& simd_type) { ivec_L2sqr = ivec_L2sqr_avx512; fp16_vec_inner_product = fp16_vec_inner_product_avx512; - bf16_vec_inner_product = bf16_vec_inner_product_avx512; fp16_vec_L2sqr = fp16_vec_L2sqr_avx512; - bf16_vec_L2sqr = bf16_vec_L2sqr_avx512; fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_avx512; + + bf16_vec_inner_product = bf16_vec_inner_product_avx512; + bf16_vec_L2sqr = bf16_vec_L2sqr_avx512; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_avx512; + simd_type = "AVX512"; support_pq_fast_scan = true; } else if (use_avx2 && cpu_support_avx2()) { @@ -211,10 +213,11 @@ fvec_hook(std::string& simd_type) { ivec_L2sqr = ivec_L2sqr_avx; fp16_vec_inner_product = fp16_vec_inner_product_avx; - bf16_vec_inner_product = bf16_vec_inner_product_avx; fp16_vec_L2sqr = fp16_vec_L2sqr_avx; - bf16_vec_L2sqr = bf16_vec_L2sqr_avx; fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_avx; + + bf16_vec_inner_product = bf16_vec_inner_product_avx; + bf16_vec_L2sqr = bf16_vec_L2sqr_avx; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_avx; simd_type = "AVX2"; @@ -236,14 +239,15 @@ fvec_hook(std::string& simd_type) { ivec_inner_product = ivec_inner_product_sse; ivec_L2sqr = ivec_L2sqr_sse; - bf16_vec_inner_product = bf16_vec_inner_product_sse; - bf16_vec_L2sqr = bf16_vec_L2sqr_sse; - bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_sse; fp16_vec_inner_product = fp16_vec_inner_product_ref; fp16_vec_L2sqr = fp16_vec_L2sqr_ref; fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_ref; + bf16_vec_inner_product = bf16_vec_inner_product_sse; + bf16_vec_L2sqr = bf16_vec_L2sqr_sse; + bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_sse; + simd_type = "SSE4_2"; support_pq_fast_scan = false; } else { @@ -265,10 +269,11 @@ fvec_hook(std::string& simd_type) { ivec_L2sqr = ivec_L2sqr_ref; fp16_vec_inner_product = fp16_vec_inner_product_ref; - bf16_vec_inner_product = bf16_vec_inner_product_ref; fp16_vec_L2sqr = fp16_vec_L2sqr_ref; - bf16_vec_L2sqr = bf16_vec_L2sqr_ref; fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_ref; + + bf16_vec_inner_product = bf16_vec_inner_product_ref; + bf16_vec_L2sqr = bf16_vec_L2sqr_ref; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_ref; simd_type = "GENERIC"; @@ -290,11 +295,13 @@ fvec_hook(std::string& simd_type) { ivec_inner_product = ivec_inner_product_neon; ivec_L2sqr = ivec_L2sqr_neon; + fp16_vec_inner_product = fp16_vec_inner_product_neon; - bf16_vec_inner_product = bf16_vec_inner_product_neon; fp16_vec_L2sqr = fp16_vec_L2sqr_neon; - bf16_vec_L2sqr = bf16_vec_L2sqr_neon; fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon; + + bf16_vec_inner_product = bf16_vec_inner_product_neon; + bf16_vec_L2sqr = bf16_vec_L2sqr_neon; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon; simd_type = "NEON"; diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index 5e15227f4..a6bcd809c 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -13,203 +13,211 @@ #include "catch2/catch_test_macros.hpp" #include "catch2/generators/catch_generators.hpp" #include "catch2/matchers/catch_matchers_floating_point.hpp" -#include "knowhere/comp/brute_force.h" -#include "knowhere/comp/index_param.h" #include "knowhere/comp/knowhere_config.h" #include "simd/distances_ref.h" #include "simd/hook.h" -#if defined(__x86_64__) -#include "simd/distances_avx.h" -#include "simd/distances_avx512.h" -#include "simd/distances_sse.h" -#endif - -#if defined(__ARM_NEON) -#include "simd/distances_neon.h" -#endif - #include "utils.h" + template std::unique_ptr -GenRandomVector(int dim, int seed = 42) { +GenRandomVector(int dim, int rows, int seed) { std::mt19937 rng(seed); - std::uniform_real_distribution<> distrib(0.0, 100.0); - auto x = std::make_unique(dim); - for (int i = 0; i < dim; ++i) x[i] = (DataType)distrib(rng); + std::uniform_real_distribution<> distrib(-10.0, 10.0); + auto x = std::make_unique(rows * dim); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < dim; j++) { + x[i * dim + j] = (DataType)distrib(rng); + } + } return x; } -TEST_CASE("Test BruteForce Search SIMD", "[bf]") { - using Catch::Approx; - - const int64_t nb = 1000; - const int64_t nq = 10; - const int64_t dim = 127; - const int64_t k = 5; - - auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); - - const auto train_ds = GenDataSet(nb, dim); - const auto query_ds = CopyDataSet(train_ds, nq); - - knowhere::Json conf = { - {knowhere::meta::DIM, dim}, - {knowhere::meta::METRIC_TYPE, metric}, - {knowhere::meta::TOPK, k}, - }; - - auto test_search_with_simd = [&](knowhere::KnowhereConfig::SimdType simd_type) { - knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); - REQUIRE(gt.has_value()); - auto gt_ids = gt.value()->GetIds(); - auto gt_dist = gt.value()->GetDistance(); - - for (int64_t i = 0; i < nq; i++) { - REQUIRE(gt_ids[i * k] == i); - if (metric == knowhere::metric::L2) { - REQUIRE(gt_dist[i * k] == 0); - } else { - REQUIRE(std::abs(gt_dist[i * k] - 1.0) < 0.00001); - } +template +std::unique_ptr +ConvertVector(float* data, int dim, int rows) { + auto x = std::make_unique(rows * dim); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < dim; j++) { + x[i * dim + j] = (DataType)data[i * dim + j]; } - }; - - for (auto simd_type : {knowhere::KnowhereConfig::SimdType::AVX512, knowhere::KnowhereConfig::SimdType::AVX2, - knowhere::KnowhereConfig::SimdType::SSE4_2, knowhere::KnowhereConfig::SimdType::GENERIC, - knowhere::KnowhereConfig::SimdType::AUTO}) { - test_search_with_simd(simd_type); } + return x; } -TEST_CASE("Test PQ Search SIMD", "[pq]") { +TEST_CASE("Test distance") { using Catch::Approx; + auto simd_type = GENERATE(as{}, knowhere::KnowhereConfig::SimdType::AVX512, + knowhere::KnowhereConfig::SimdType::AVX2, knowhere::KnowhereConfig::SimdType::SSE4_2, + knowhere::KnowhereConfig::SimdType::GENERIC, knowhere::KnowhereConfig::SimdType::AUTO); + auto dim = GENERATE(as{}, 1, 7, 14, 21, 28, 35, 42, 49, 56, 64, 128, 256, 512); + + LOG_KNOWHERE_INFO_ << "simd type: " << simd_type << ", dim: " << dim; + knowhere::KnowhereConfig::SetSimdType(simd_type); + + SECTION("test single distance calculation") { + const size_t nx = 1, ny = 1; + + const float tolerance = 0.000001f; + const auto x = GenRandomVector(dim, nx, 314); + const auto y = GenRandomVector(dim, ny, 271); + + // fp16's accuracy is 0.001, during calculation we let the tolerance be 0.002 + const float fp16_tolerance = 0.002f; + const auto x_fp16 = ConvertVector(x.get(), nx, dim); + const auto y_fp16 = ConvertVector(y.get(), ny, dim); + + // bf16's accuracy is 0.01, during calculation we let the tolerance be 0.02 + const float bf16_tolerance = 0.02f; + const auto x_bf16 = ConvertVector(x.get(), nx, dim); + const auto y_bf16 = ConvertVector(y.get(), ny, dim); + + // int8 + const auto xi = ConvertVector(x.get(), nx, dim); + const auto yi = ConvertVector(y.get(), ny, dim); + + const auto ref_ip = faiss::fvec_inner_product_ref(x.get(), y.get(), dim); + const auto ref_L2sqr = faiss::fvec_L2sqr_ref(x.get(), y.get(), dim); + const auto ref_L1 = faiss::fvec_L1_ref(x.get(), y.get(), dim); + const auto ref_Linf = faiss::fvec_Linf_ref(x.get(), y.get(), dim); + const auto ref_norm_L2sqr = faiss::fvec_norm_L2sqr_ref(x.get(), dim); + + const auto ref_i_ip = faiss::ivec_inner_product_ref(xi.get(), yi.get(), dim); + const auto ref_i_L2sqr = faiss::ivec_L2sqr_ref(xi.get(), yi.get(), dim); + + // float + REQUIRE_THAT(faiss::fvec_inner_product(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_ip, tolerance)); + REQUIRE_THAT(faiss::fvec_L2sqr(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_L2sqr, tolerance)); + REQUIRE_THAT(faiss::fvec_L1(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_L1, tolerance)); + REQUIRE_THAT(faiss::fvec_Linf(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_Linf, tolerance)); + REQUIRE_THAT(faiss::fvec_norm_L2sqr(x.get(), dim), Catch::Matchers::WithinRel(ref_norm_L2sqr, tolerance)); + + // fp16 + REQUIRE_THAT(faiss::fp16_vec_inner_product(x_fp16.get(), y_fp16.get(), dim), + Catch::Matchers::WithinRel(ref_ip, fp16_tolerance)); + REQUIRE_THAT(faiss::fp16_vec_L2sqr(x_fp16.get(), y_fp16.get(), dim), + Catch::Matchers::WithinRel(ref_L2sqr, fp16_tolerance)); + REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr(x_fp16.get(), dim), + Catch::Matchers::WithinRel(ref_norm_L2sqr, fp16_tolerance)); + + // bf16 + REQUIRE_THAT(faiss::bf16_vec_inner_product(x_bf16.get(), y_bf16.get(), dim), + Catch::Matchers::WithinRel(ref_ip, bf16_tolerance)); + REQUIRE_THAT(faiss::bf16_vec_L2sqr(x_bf16.get(), y_bf16.get(), dim), + Catch::Matchers::WithinRel(ref_L2sqr, bf16_tolerance)); + REQUIRE_THAT(faiss::bf16_vec_norm_L2sqr(x_bf16.get(), dim), + Catch::Matchers::WithinRel(ref_norm_L2sqr, bf16_tolerance)); + + // int8 + CHECK_EQ(faiss::ivec_inner_product(xi.get(), yi.get(), dim), ref_i_ip); + CHECK_EQ(faiss::ivec_L2sqr(xi.get(), yi.get(), dim), ref_i_L2sqr); + } + + SECTION("test ny distance calculation") { + const size_t nx = 1, ny = 10; + + const float tolerance = 0.0001f; + const auto x = GenRandomVector(dim, nx, 314); + const auto y = GenRandomVector(dim, ny, 271); + + auto ref_ip = std::make_unique(ny); + faiss::fvec_inner_products_ny_ref(ref_ip.get(), x.get(), y.get(), dim, ny); + auto ref_l2 = std::make_unique(ny); + faiss::fvec_L2sqr_ny_ref(ref_l2.get(), x.get(), y.get(), dim, ny); - const int64_t nb = 1000; - const int64_t nq = 10; - const int64_t dim = 128; - const int64_t k = 5; - - auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); - auto version = GenTestVersionList(); - - const auto train_ds = GenDataSet(nb, dim); - const auto query_ds = CopyDataSet(train_ds, nq); - - knowhere::Json conf = { - {knowhere::meta::DIM, dim}, {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, k}, - {knowhere::indexparam::NLIST, 16}, {knowhere::indexparam::NPROBE, 8}, {knowhere::indexparam::NBITS, 8}, - }; - - auto test_search_with_simd = [&](const int64_t m, knowhere::KnowhereConfig::SimdType simd_type) { - conf[knowhere::indexparam::M] = m; - - knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); - REQUIRE(gt.has_value()); - auto gt_ids = gt.value()->GetIds(); - auto gt_dist = gt.value()->GetDistance(); - - for (int64_t i = 0; i < nq; i++) { - REQUIRE(gt_ids[i * k] == i); - if (metric == knowhere::metric::L2) { - REQUIRE(gt_dist[i * k] == 0); - } else { - REQUIRE(std::abs(gt_dist[i * k] - 1.0) < 0.00001); - } + auto dis = std::make_unique(ny); + + faiss::fvec_inner_products_ny(dis.get(), x.get(), y.get(), dim, ny); + for (size_t i = 0; i < ny; i++) { + REQUIRE_THAT(dis[i], Catch::Matchers::WithinRel(ref_ip[i], tolerance)); } - auto idx = knowhere::IndexFactory::Instance() - .Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version) - .value(); - REQUIRE(idx.Build(train_ds, conf) == knowhere::Status::success); - auto res = idx.Search(query_ds, conf, nullptr); - REQUIRE(res.has_value()); - float recall = GetKNNRecall(*gt.value(), *res.value()); - REQUIRE(recall > 0.2); - }; - - for (auto simd_type : {knowhere::KnowhereConfig::SimdType::GENERIC, knowhere::KnowhereConfig::SimdType::AUTO}) { - for (int64_t m : {8, 16, 32, 64, 128}) { - test_search_with_simd(m, simd_type); + faiss::fvec_L2sqr_ny(dis.get(), x.get(), y.get(), dim, ny); + for (size_t i = 0; i < ny; i++) { + REQUIRE_THAT(dis[i], Catch::Matchers::WithinRel(ref_l2[i], tolerance)); } } -} -TEST_CASE("Test fp16 distance", "[fp16]") { - using Catch::Approx; - auto dim = GENERATE(as{}, 1, 2, 10, 69, 128, 141, 510, 1024); - - auto x = GenRandomVector(dim, 11); - auto y = GenRandomVector(dim, 22); - auto ref_l2_dist = faiss::fp16_vec_L2sqr_ref(x.get(), y.get(), dim); - auto ref_ip_dist = faiss::fp16_vec_inner_product_ref(x.get(), y.get(), dim); - auto ref_norm_l2_dist = faiss::fp16_vec_norm_L2sqr_ref(x.get(), dim); -#if defined(__ARM_NEON) - // neon - REQUIRE_THAT(faiss::fp16_vec_L2sqr_neon(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_inner_product_neon(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr_neon(x.get(), dim), Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); -#endif -#if defined(__x86_64__) - if (faiss::cpu_support_avx2()) { - REQUIRE_THAT(faiss::fp16_vec_L2sqr_avx(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_inner_product_avx(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr_avx(x.get(), dim), - Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); - } - if (faiss::cpu_support_avx512()) { - REQUIRE_THAT(faiss::fp16_vec_L2sqr_avx512(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_inner_product_avx512(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr_avx512(x.get(), dim), - Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); - } -#endif -} + SECTION("test madd distance calculation") { + const size_t n = 1; -TEST_CASE("Test bf16 distance", "[bf16]") { - using Catch::Approx; + const float tolerance = 0.0001f; + const auto a = GenRandomVector(dim, n, 314); + const auto b = GenRandomVector(dim, n, 271); + const float bf = 3.14159; - auto dim = GENERATE(as{}, 1, 2, 10, 69, 128, 141, 510, 1024); - - auto x = GenRandomVector(dim, 11); - auto y = GenRandomVector(dim, 22); - auto ref_l2_dist = faiss::bf16_vec_L2sqr_ref(x.get(), y.get(), dim); - auto ref_ip_dist = faiss::bf16_vec_inner_product_ref(x.get(), y.get(), dim); - auto ref_norm_l2_dist = faiss::bf16_vec_norm_L2sqr_ref(x.get(), dim); -#if defined(__ARM_NEON) - // neon - REQUIRE_THAT(faiss::bf16_vec_L2sqr_neon(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_inner_product_neon(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_norm_L2sqr_neon(x.get(), dim), Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); -#endif -#if defined(__x86_64__) - if (faiss::cpu_support_sse4_2()) { - REQUIRE_THAT(faiss::bf16_vec_L2sqr_sse(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_inner_product_sse(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_norm_L2sqr_sse(x.get(), dim), - Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); - } - if (faiss::cpu_support_avx2()) { - REQUIRE_THAT(faiss::bf16_vec_L2sqr_avx(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_inner_product_avx(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_norm_L2sqr_avx(x.get(), dim), - Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); + auto ref_madd = std::make_unique(dim); + faiss::fvec_madd_ref(dim, a.get(), bf, b.get(), ref_madd.get()); + auto ref_madd_and_argmin = std::make_unique(dim); + faiss::fvec_madd_and_argmin_ref(dim, a.get(), bf, b.get(), ref_madd_and_argmin.get()); + + auto dis = std::make_unique(dim); + + faiss::fvec_madd(dim, a.get(), bf, b.get(), dis.get()); + for (size_t i = 0; i < dim; i++) { + REQUIRE_THAT(dis[i], Catch::Matchers::WithinRel(ref_madd[i], tolerance)); + } + + faiss::fvec_madd_and_argmin(dim, a.get(), bf, b.get(), dis.get()); + for (size_t i = 0; i < dim; i++) { + REQUIRE_THAT(dis[i], Catch::Matchers::WithinRel(ref_madd_and_argmin[i], tolerance)); + } } - if (faiss::cpu_support_avx512()) { - REQUIRE_THAT(faiss::bf16_vec_L2sqr_avx512(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_l2_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_inner_product_avx512(x.get(), y.get(), dim), - Catch::Matchers::WithinRel(ref_ip_dist, 0.001f)); - REQUIRE_THAT(faiss::bf16_vec_norm_L2sqr_avx512(x.get(), dim), - Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f)); + + SECTION("test batch_4 distance calculation") { + const size_t nx = 1, ny = 1; + + float tolerance = 0.00001f; + const auto x = GenRandomVector(dim, nx, 314); + const auto y = GenRandomVector(dim, ny, 271); + + const auto ref_ip = faiss::fvec_inner_product_ref(x.get(), y.get(), dim); + const auto ref_L2sqr = faiss::fvec_L2sqr_ref(x.get(), y.get(), dim); + + float batch_tolerance = 0.0002f; + const auto y0 = GenRandomVector(dim, ny, 271); + const auto y1 = GenRandomVector(dim, ny, 272); + const auto y2 = GenRandomVector(dim, ny, 273); + const auto y3 = GenRandomVector(dim, ny, 274); + + float ref_ip_0, ref_ip_1, ref_ip_2, ref_ip_3; + faiss::fvec_inner_product_batch_4_ref(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, ref_ip_0, ref_ip_1, + ref_ip_2, ref_ip_3); + float ref_l2_0, ref_l2_1, ref_l2_2, ref_l2_3; + faiss::fvec_L2sqr_batch_4_ref(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, ref_l2_0, ref_l2_1, + ref_l2_2, ref_l2_3); + + auto run_test = [&]() { + // float + REQUIRE_THAT(faiss::fvec_inner_product(x.get(), y.get(), dim), + Catch::Matchers::WithinRel(ref_ip, tolerance)); + REQUIRE_THAT(faiss::fvec_L2sqr(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_L2sqr, tolerance)); + + // batch + float dis0, dis1, dis2, dis3; + + faiss::fvec_inner_product_batch_4(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, dis0, dis1, dis2, + dis3); + REQUIRE_THAT(dis0, Catch::Matchers::WithinRel(ref_ip_0, batch_tolerance)); + REQUIRE_THAT(dis1, Catch::Matchers::WithinRel(ref_ip_1, batch_tolerance)); + REQUIRE_THAT(dis2, Catch::Matchers::WithinRel(ref_ip_2, batch_tolerance)); + REQUIRE_THAT(dis3, Catch::Matchers::WithinRel(ref_ip_3, batch_tolerance)); + + faiss::fvec_L2sqr_batch_4(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, dis0, dis1, dis2, dis3); + REQUIRE_THAT(dis0, Catch::Matchers::WithinRel(ref_l2_0, batch_tolerance)); + REQUIRE_THAT(dis1, Catch::Matchers::WithinRel(ref_l2_1, batch_tolerance)); + REQUIRE_THAT(dis2, Catch::Matchers::WithinRel(ref_l2_2, batch_tolerance)); + REQUIRE_THAT(dis3, Catch::Matchers::WithinRel(ref_l2_3, batch_tolerance)); + }; + + tolerance = 0.02f; + batch_tolerance = 0.05f; + knowhere::KnowhereConfig::EnablePatchForComputeFP32AsBF16(); + // TODO caiyd: need enable this test + // run_test(); + + tolerance = 0.00001f; + batch_tolerance = 0.0002f; + knowhere::KnowhereConfig::DisablePatchForComputeFP32AsBF16(); + run_test(); } -#endif }