Skip to content

Commit

Permalink
Add more unittest for src/common (#850)
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Sep 19, 2024
1 parent eabb3e0 commit cda1401
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 47 deletions.
6 changes: 1 addition & 5 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,8 @@ class ThreadPool {
if (num_threads <= 0) {
return;
}
if (build_pool_ == nullptr) { // this should not happen in prod
omp_before = omp_get_max_threads();
} else {
omp_before = build_pool_->size();
}

omp_before = (build_pool_ ? build_pool_->size() : omp_get_max_threads());
#ifdef OPENBLAS_OS_LINUX
blas_thread_before = openblas_get_num_threads();
openblas_set_num_threads(num_threads);
Expand Down
2 changes: 1 addition & 1 deletion src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
auto value_str = json[it.first].get<std::string>();
CFG_INT::value_type v = std::stoi(value_str.c_str(), &sz);
if (sz < value_str.length()) {
throw KnowhereException(std::string("wrong data type in json ") + value_str);
KNOWHERE_THROW_MSG(std::string("wrong data type in json ") + value_str);
}
json[it.first] = v;
}
Expand Down
12 changes: 6 additions & 6 deletions src/common/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ initTelemetry(const TraceConfig& cfg) {
opts.endpoint = cfg.jaegerURL;
exporter = jaeger::JaegerExporterFactory::Create(opts);
LOG_KNOWHERE_INFO_ << "init jaeger exporter, endpoint: " << opts.endpoint;
} else if (cfg.exporter == "otlp") {
auto opts = otlp::OtlpGrpcExporterOptions{};
opts.endpoint = cfg.otlpEndpoint;
opts.use_ssl_credentials = cfg.oltpSecure;
exporter = otlp::OtlpGrpcExporterFactory::Create(opts);
LOG_KNOWHERE_INFO_ << "init otlp exporter, endpoint: " << opts.endpoint;
// } else if (cfg.exporter == "otlp") {
// auto opts = otlp::OtlpGrpcExporterOptions{};
// opts.endpoint = cfg.otlpEndpoint;
// opts.use_ssl_credentials = cfg.oltpSecure;
// exporter = otlp::OtlpGrpcExporterFactory::Create(opts);
// LOG_KNOWHERE_INFO_ << "init otlp exporter, endpoint: " << opts.endpoint;
} else {
LOG_KNOWHERE_INFO_ << "Empty Trace";
enable_trace = false;
Expand Down
4 changes: 1 addition & 3 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ NormalizeDataset(const DataSetPtr dataset) {

LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim;

for (int32_t i = 0; i < rows; i++) {
NormalizeVec<DataType>(data + i * dim, dim);
}
NormalizeVecs<DataType>(data, rows, dim);
}

void
Expand Down
58 changes: 58 additions & 0 deletions tests/ut/test_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -897,3 +897,61 @@ TEST_CASE("Test config load", "[MATERIALIZED_VIEW_SEARCH_INFO]") {
CHECK(s == knowhere::Status::success);
}
}

TEST_CASE("Test config", "[FormatAndCheck]") {
knowhere::Status s;
std::string err_msg;

SECTION("check config with string type values") {
class TestConfig : public knowhere::Config {
public:
CFG_INT int_val;
CFG_FLOAT float_val;
CFG_BOOL true_val;
CFG_BOOL false_val;
KNOHWERE_DECLARE_CONFIG(TestConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(int_val).for_train_and_search();
KNOWHERE_CONFIG_DECLARE_FIELD(float_val).for_train_and_search();
KNOWHERE_CONFIG_DECLARE_FIELD(true_val).for_train_and_search();
KNOWHERE_CONFIG_DECLARE_FIELD(false_val).for_train_and_search();
}
};

TestConfig test_cfg;
knowhere::Json json;

json = knowhere::Json::parse(R"({
"int_val": "123",
"float_val": "1.23",
"true_val": "true",
"false_val": "false"
})");
s = knowhere::Config::FormatAndCheck(test_cfg, json, &err_msg);
CHECK(s == knowhere::Status::success);
s = knowhere::Config::Load(test_cfg, json, knowhere::SEARCH, &err_msg);
CHECK(s == knowhere::Status::success);
CHECK(test_cfg.int_val.value() == 123);
CHECK_LT(std::abs(test_cfg.float_val.value() - 1.23), 0.00001);
CHECK(test_cfg.true_val.value() == true);
CHECK(test_cfg.false_val.value() == false);
}

SECTION("check config with invalid string type int value") {
class TestConfig : public knowhere::Config {
public:
CFG_INT int_val;
KNOHWERE_DECLARE_CONFIG(TestConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(int_val).for_train_and_search();
}
};

TestConfig test_cfg;
knowhere::Json json;

json = knowhere::Json::parse(R"({
"int_val": "12.3"
})");
s = knowhere::Config::FormatAndCheck(test_cfg, json, &err_msg);
CHECK(s == knowhere::Status::invalid_value_in_json);
}
}
43 changes: 30 additions & 13 deletions tests/ut/test_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

#include <unistd.h>

#include <memory>

#include "catch2/catch_approx.hpp"
Expand All @@ -22,20 +24,35 @@ using namespace opentelemetry::trace;
TEST_CASE("Test Tracer init", "Init test") {
using Catch::Approx;

auto config = std::make_shared<TraceConfig>();
config->exporter = "stdout";
config->nodeID = 1;
initTelemetry(*config);
auto span = StartSpan("test");
REQUIRE(span->IsRecording());
SECTION("check stdout") {
auto config = std::make_shared<TraceConfig>();
config->exporter = "stdout";
config->nodeID = 1;
initTelemetry(*config);
auto span = StartSpan("test");
REQUIRE(span->IsRecording());

SetRootSpan(span);
AddEvent("sleep");
usleep(20000);
CloseRootSpan();
}

SECTION("check jaeger") {
auto config = std::make_shared<TraceConfig>();
config->exporter = "jaeger";
// use default jaeger collector port for test
config->jaegerURL = "http://localhost:14268/api/traces";
config->nodeID = 1;
initTelemetry(*config);
auto span = StartSpan("test");
REQUIRE(span->IsRecording());

config = std::make_shared<TraceConfig>();
config->exporter = "jaeger";
config->jaegerURL = "http://localhost:14268/api/traces";
config->nodeID = 1;
initTelemetry(*config);
span = StartSpan("test");
REQUIRE(span->IsRecording());
SetRootSpan(span);
AddEvent("sleep");
usleep(20000);
CloseRootSpan();
}
}

TEST_CASE("Test Tracer span", "Span test") {
Expand Down
73 changes: 54 additions & 19 deletions tests/ut/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,60 @@ namespace {
const std::vector<size_t> kBitsetSizes{4, 8, 10, 64, 100, 500, 1024};
}

template <typename T>
void
CheckNormalizeDataset(int rows, int dim, float diff) {
auto ds = GenDataSet(rows, dim);
auto type_ds = knowhere::ConvertToDataTypeIfNeeded<T>(ds);
auto data = (T*)type_ds->GetTensor();

knowhere::NormalizeDataset<T>(type_ds);

for (int i = 0; i < rows; ++i) {
float sum = 0.0;
for (int j = 0; j < dim; ++j) {
auto val = data[i * dim + j];
sum += val * val;
}
CHECK(std::abs(1.0f - sum) <= diff);
}
}

template <typename T>
void
CheckCopyAndNormalizeVecs(int rows, int dim, float diff) {
auto ds = GenDataSet(rows, dim);
auto type_ds = knowhere::ConvertToDataTypeIfNeeded<T>(ds);
auto data = (T*)type_ds->GetTensor();

auto data_copy = knowhere::CopyAndNormalizeVecs<T>(data, rows, dim);

for (int i = 0; i < rows; ++i) {
float sum = 0.0;
for (int j = 0; j < dim; ++j) {
auto val = data_copy[i * dim + j];
sum += val * val;
}
CHECK(std::abs(1.0f - sum) <= diff);
}
}

TEST_CASE("Test Vector Normalization", "[normalize]") {
using Catch::Approx;

const float floatDiff = 0.00001;
uint64_t nb = 1000000;
uint64_t rows = 100;
uint64_t dim = 128;
int64_t seed = 42;

SECTION("Test normalize") {
auto train_ds = GenDataSet(nb, dim, seed);
auto data = (float*)train_ds->GetTensor();

knowhere::NormalizeDataset<knowhere::fp32>(train_ds);
SECTION("Test Normalize Dataset") {
CheckNormalizeDataset<knowhere::fp32>(rows, dim, 0.00001);
CheckNormalizeDataset<knowhere::fp16>(rows, dim, 0.001);
CheckNormalizeDataset<knowhere::bf16>(rows, dim, 0.01);
}

for (size_t i = 0; i < nb; ++i) {
float sum = 0.0;
for (size_t j = 0; j < dim; ++j) {
auto val = data[i * dim + j];
sum += val * val;
}
CHECK(std::abs(1.0f - sum) <= floatDiff);
}
SECTION("Test Copy and Normalize Vectors") {
CheckCopyAndNormalizeVecs<knowhere::fp32>(rows, dim, 0.00001);
CheckCopyAndNormalizeVecs<knowhere::fp16>(rows, dim, 0.001);
CheckCopyAndNormalizeVecs<knowhere::bf16>(rows, dim, 0.01);
}
}

Expand Down Expand Up @@ -175,10 +207,13 @@ TEST_CASE("Test ThreadPool") {
SECTION("ScopedOmpSetter") {
int prev_num_threads = omp_get_max_threads();
{
knowhere::ThreadPool::ScopedOmpSetter setter(2 * prev_num_threads);
REQUIRE(omp_get_max_threads() == 2 * prev_num_threads);
int target_num_threads = (prev_num_threads / 2) > 0 ? (prev_num_threads / 2) : 1;
knowhere::ThreadPool::ScopedOmpSetter setter(target_num_threads);
auto thread_num = omp_get_max_threads();
REQUIRE(thread_num == target_num_threads);
#ifdef OPENBLAS_OS_LINUX
REQUIRE(openblas_get_num_threads() == 2 * prev_num_threads);
auto openblas_thread_num = openblas_get_num_threads();
REQUIRE(openblas_thread_num == target_num_threads);
#endif
}
}
Expand Down

0 comments on commit cda1401

Please sign in to comment.