From 4c88001ebef69265160f070448848f5188c9a132 Mon Sep 17 00:00:00 2001 From: Alexandr-Solovev Date: Thu, 21 Nov 2024 05:26:51 -0800 Subject: [PATCH] add eigen lib --- WORKSPACE | 8 +++ .../primitives/lapack/test/syevd_dpc.cpp | 51 +++++++++++++++++-- cpp/oneapi/dal/test/engine/BUILD | 1 + cpp/oneapi/dal/test/engine/common.hpp | 3 ++ dev/bazel/deps/boost.tpl.BUILD | 5 +- dev/bazel/deps/eigen.tpl.BUILD | 8 +++ 6 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 dev/bazel/deps/eigen.tpl.BUILD diff --git a/WORKSPACE b/WORKSPACE index 64fc4f0eed3..2597b8765c9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -121,6 +121,14 @@ http_archive( strip_prefix = "Catch2-3.7.1", ) +http_archive( + name = "eigen", + url = "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz", + sha256 = "8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72", + build_file = "@onedal//dev/bazel/deps:eigen.tpl.BUILD", + strip_prefix = "eigen-3.4.0", +) + http_archive( name = "fmt", url = "https://github.com/fmtlib/fmt/archive/11.0.2.tar.gz", diff --git a/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp b/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp index 56484014a81..212572f511f 100644 --- a/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp @@ -128,6 +128,46 @@ class syevd_test : public te::float_algo_fixture { } } + void check_eigvals_with_eigen(const la::matrix& s, + const la::matrix& eigvecs, + const la::matrix& eigvals) const { + INFO("convert results to float64"); + const auto s_f64 = la::astype(s); + const auto eigvals_f64 = la::astype(eigvals); + const auto eigvecs_f64 = la::astype(eigvecs); + std::int64_t row_count = s.get_row_count(); + std::int64_t column_count = s.get_column_count(); + const Float* data = s.get_data(); + + Eigen::Matrix eigen_matrix(row_count, column_count); + for (int i = 0; i < eigen_matrix.rows(); ++i) { + for (int j = 0; j < eigen_matrix.cols(); ++j) { + eigen_matrix(i, j) = data[i * column_count + j]; + } + } + + Eigen::SelfAdjointEigenSolver> es( + eigen_matrix); + + auto eigenvalues = es.eigenvalues().real(); + INFO("oneDAL eigvals vs Eigen eigvals"); + la::enumerate_linear(eigvals_f64, [&](std::int64_t i, Float x) { + REQUIRE(abs(eigvals_f64.get(i) - eigenvalues(i)) < 0.1); + }); + + INFO("oneDAL eigvectors vs Eigen eigvectors"); + auto eigenvectors = es.eigenvectors().real(); + + const double* eigenvec_ptr = eigvecs_f64.get_data(); + //TODO: investigate Eigen classes and align checking between oneDAL and Eigen classes. + for (int j = 0; j < eigvecs.get_column_count(); ++j) { + auto column_eigen = eigenvectors.col(j); + for (int i = 0; i < eigvecs.get_row_count(); ++i) { + REQUIRE((abs(eigenvec_ptr[j * row_count + i]) - abs(column_eigen(i))) < 0.1); + } + } + } + void check_eigvals_are_ascending(const la::matrix& eigvals) const { INFO("check eigenvalues order is ascending"); la::enumerate_linear(eigvals, [&](std::int64_t i, Float x) { @@ -158,14 +198,15 @@ TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix", "[sym_eigvals this->check_eigvals_definition(s, eigenvectors, eigenvalues); this->check_eigvals_are_ascending(eigenvalues); + this->check_eigvals_with_eigen(s, eigenvectors, eigenvalues); } -TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix 2", "[sym_eigvals]", eigen_types) { - const auto s = this->generate_symmetric_positive(); +// TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix 2", "[sym_eigvals]", eigen_types) { +// const auto s = this->generate_symmetric_positive(); - const auto [eigenvectors, eigenvalues] = this->call_sym_eigvals_inplace_descending(s); +// const auto [eigenvectors, eigenvalues] = this->call_sym_eigvals_inplace_descending(s); - this->check_eigvals_are_descending(eigenvalues); -} +// this->check_eigvals_are_descending(eigenvalues); +// } } // namespace oneapi::dal::backend::primitives::test diff --git a/cpp/oneapi/dal/test/engine/BUILD b/cpp/oneapi/dal/test/engine/BUILD index fd1cacfc618..5732edd02a2 100644 --- a/cpp/oneapi/dal/test/engine/BUILD +++ b/cpp/oneapi/dal/test/engine/BUILD @@ -24,6 +24,7 @@ dal_test_module( extra_deps = [ "@boost//:boost", "@catch2//:catch2", + "@eigen//:eigen", "@fmt//:fmt", ], ) diff --git a/cpp/oneapi/dal/test/engine/common.hpp b/cpp/oneapi/dal/test/engine/common.hpp index 542b364e886..765aadfc817 100644 --- a/cpp/oneapi/dal/test/engine/common.hpp +++ b/cpp/oneapi/dal/test/engine/common.hpp @@ -25,6 +25,9 @@ //Necessary headers from boost #include +#include +#include + #include "oneapi/dal/train.hpp" #include "oneapi/dal/infer.hpp" #include "oneapi/dal/compute.hpp" diff --git a/dev/bazel/deps/boost.tpl.BUILD b/dev/bazel/deps/boost.tpl.BUILD index ae79787b5d3..e6d5c32872c 100644 --- a/dev/bazel/deps/boost.tpl.BUILD +++ b/dev/bazel/deps/boost.tpl.BUILD @@ -3,8 +3,7 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "boost", srcs = glob([ - "boost/libs/serialization/src/**/*.cpp", - "boost/libs/libboost*.a", + "libs/libboost*.a", ]), hdrs = glob([ "boost/**/*.h", @@ -12,7 +11,7 @@ cc_library( "boost/**/*.ipp", ]), includes = [ - "boost", + ".", ], visibility = ["//visibility:public"], ) diff --git a/dev/bazel/deps/eigen.tpl.BUILD b/dev/bazel/deps/eigen.tpl.BUILD new file mode 100644 index 00000000000..e4d892ecfcb --- /dev/null +++ b/dev/bazel/deps/eigen.tpl.BUILD @@ -0,0 +1,8 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "eigen", + hdrs = glob(["Eigen/**"]), + includes = [""], + visibility = ["//visibility:public"], +)