Skip to content

Commit

Permalink
add eigen lib
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Nov 21, 2024
1 parent 2dc78c7 commit 4c88001
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 8 deletions.
8 changes: 8 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ http_archive(
strip_prefix = "Catch2-3.7.1",
)

http_archive(
name = "eigen",
url = "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz",
sha256 = "8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72",
build_file = "@onedal//dev/bazel/deps:eigen.tpl.BUILD",
strip_prefix = "eigen-3.4.0",
)

http_archive(
name = "fmt",
url = "https://github.com/fmtlib/fmt/archive/11.0.2.tar.gz",
Expand Down
51 changes: 46 additions & 5 deletions cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,46 @@ class syevd_test : public te::float_algo_fixture<Float> {
}
}

void check_eigvals_with_eigen(const la::matrix<Float>& s,
const la::matrix<Float>& eigvecs,
const la::matrix<Float>& eigvals) const {
INFO("convert results to float64");
const auto s_f64 = la::astype<double>(s);
const auto eigvals_f64 = la::astype<double>(eigvals);
const auto eigvecs_f64 = la::astype<double>(eigvecs);
std::int64_t row_count = s.get_row_count();
std::int64_t column_count = s.get_column_count();
const Float* data = s.get_data();

Eigen::Matrix<Float, Eigen::Dynamic, Eigen::Dynamic> eigen_matrix(row_count, column_count);
for (int i = 0; i < eigen_matrix.rows(); ++i) {
for (int j = 0; j < eigen_matrix.cols(); ++j) {
eigen_matrix(i, j) = data[i * column_count + j];
}
}

Eigen::SelfAdjointEigenSolver<Eigen::Matrix<Float, Eigen::Dynamic, Eigen::Dynamic>> es(
eigen_matrix);

auto eigenvalues = es.eigenvalues().real();
INFO("oneDAL eigvals vs Eigen eigvals");
la::enumerate_linear(eigvals_f64, [&](std::int64_t i, Float x) {
REQUIRE(abs(eigvals_f64.get(i) - eigenvalues(i)) < 0.1);
});

INFO("oneDAL eigvectors vs Eigen eigvectors");
auto eigenvectors = es.eigenvectors().real();

const double* eigenvec_ptr = eigvecs_f64.get_data();
//TODO: investigate Eigen classes and align checking between oneDAL and Eigen classes.
for (int j = 0; j < eigvecs.get_column_count(); ++j) {
auto column_eigen = eigenvectors.col(j);
for (int i = 0; i < eigvecs.get_row_count(); ++i) {
REQUIRE((abs(eigenvec_ptr[j * row_count + i]) - abs(column_eigen(i))) < 0.1);
}
}
}

void check_eigvals_are_ascending(const la::matrix<Float>& eigvals) const {
INFO("check eigenvalues order is ascending");
la::enumerate_linear(eigvals, [&](std::int64_t i, Float x) {
Expand Down Expand Up @@ -158,14 +198,15 @@ TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix", "[sym_eigvals

this->check_eigvals_definition(s, eigenvectors, eigenvalues);
this->check_eigvals_are_ascending(eigenvalues);
this->check_eigvals_with_eigen(s, eigenvectors, eigenvalues);
}

TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix 2", "[sym_eigvals]", eigen_types) {
const auto s = this->generate_symmetric_positive();
// TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix 2", "[sym_eigvals]", eigen_types) {
// const auto s = this->generate_symmetric_positive();

const auto [eigenvectors, eigenvalues] = this->call_sym_eigvals_inplace_descending(s);
// const auto [eigenvectors, eigenvalues] = this->call_sym_eigvals_inplace_descending(s);

this->check_eigvals_are_descending(eigenvalues);
}
// this->check_eigvals_are_descending(eigenvalues);
// }

} // namespace oneapi::dal::backend::primitives::test
1 change: 1 addition & 0 deletions cpp/oneapi/dal/test/engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dal_test_module(
extra_deps = [
"@boost//:boost",
"@catch2//:catch2",
"@eigen//:eigen",
"@fmt//:fmt",
],
)
Expand Down
3 changes: 3 additions & 0 deletions cpp/oneapi/dal/test/engine/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
//Necessary headers from boost
#include <boost/process.hpp>

#include <Eigen/Dense>
#include <Eigen/Eigenvalues>

#include "oneapi/dal/train.hpp"
#include "oneapi/dal/infer.hpp"
#include "oneapi/dal/compute.hpp"
Expand Down
5 changes: 2 additions & 3 deletions dev/bazel/deps/boost.tpl.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package(default_visibility = ["//visibility:public"])
cc_library(
name = "boost",
srcs = glob([
"boost/libs/serialization/src/**/*.cpp",
"boost/libs/libboost*.a",
"libs/libboost*.a",
]),
hdrs = glob([
"boost/**/*.h",
"boost/**/*.hpp",
"boost/**/*.ipp",
]),
includes = [
"boost",
".",
],
visibility = ["//visibility:public"],
)
Expand Down
8 changes: 8 additions & 0 deletions dev/bazel/deps/eigen.tpl.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package(default_visibility = ["//visibility:public"])

cc_library(
name = "eigen",
hdrs = glob(["Eigen/**"]),
includes = [""],
visibility = ["//visibility:public"],
)

0 comments on commit 4c88001

Please sign in to comment.