diff --git a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp index d807a461e574..992cd6147b53 100644 --- a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp +++ b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp @@ -199,12 +199,18 @@ template bool test() { sycl::buffer> data(testcases, sycl::range{N}); sycl::buffer> results(sycl::range{N}); + sycl::buffer> exp_conj(sycl::range{N}); + sycl::buffer> conj_exp(sycl::range{N}); q.submit([&](sycl::handler &cgh) { sycl::accessor acc_data(data, cgh, sycl::read_only); - sycl::accessor acc(results, cgh, sycl::write_only); + sycl::accessor acc_results(results, cgh, sycl::write_only); + sycl::accessor acc_exp_conj(exp_conj, cgh, sycl::write_only); + sycl::accessor acc_conj_exp(conj_exp, cgh, sycl::write_only); cgh.parallel_for(sycl::range{N}, [=](sycl::item<1> it) { - acc[it] = std::exp(acc_data[it]); + acc_results[it] = std::exp(acc_data[it]); + acc_exp_conj[it] = std::exp(std::conj(acc_data[it])); + acc_conj_exp[it] = std::conj(std::exp(acc_data[it])); }); }).wait_and_throw(); @@ -219,9 +225,18 @@ template bool test() { // Based on https://en.cppreference.com/w/cpp/numeric/complex/exp // z below refers to the argument passed to std::exp(complex) - sycl::host_accessor acc(results); + sycl::host_accessor acc_results(results); + sycl::host_accessor acc_exp_conj(exp_conj); + sycl::host_accessor acc_conj_exp(conj_exp); for (unsigned i = 0; i < N; ++i) { - std::complex r = acc[i]; + // std::exp(std::conj(z)) == std::conj(std::exp(z)) + // NAN is not equal to NAN in floating-point arithmetic, therefore compare + // only results without NAN + if (!std::isnan(acc_exp_conj[i].real()) && + !std::isnan(acc_exp_conj[i].imag())) + CHECK(acc_exp_conj[i] == acc_conj_exp[i], passed, i); + + std::complex r = acc_results[i]; // If z is (+/-0, +0), the result is (1, +0) if (testcases[i].real() == 0 && testcases[i].imag() == 0 && !std::signbit(testcases[i].imag())) { @@ -247,6 +262,33 @@ template bool test() { CHECK(r.imag() == 0, passed, i); CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), passed, i); + // If z is (-inf, y) (for any finite y), the result is +0cis(y) where + // cis(y) is cos(y) + isin(y) + } else if (std::isinf(testcases[i].real()) && + std::signbit(testcases[i].real()) && + std::isfinite(testcases[i].imag())) { + CHECK(r.real() == 0, passed, i) + CHECK(std::signbit(r.real()) == + std::signbit(std::cos(testcases[i].imag())), + passed, i) + CHECK(r.imag() == 0, passed, i) + CHECK(std::signbit(r.imag()) == + std::signbit(std::sin(testcases[i].imag())), + passed, i) + // If z is (+inf, y) (for any finite nonzero y), the result is +∞cis(y) + // where cis(y) is cos(y) + isin(y) + } else if (std::isinf(testcases[i].real()) && + !std::signbit(testcases[i].real()) && + std::isfinite(testcases[i].imag()) && + testcases[i].imag() != 0) { + CHECK(std::isinf(r.real()), passed, i) + CHECK(std::signbit(r.real()) == + std::signbit(std::cos(testcases[i].imag())), + passed, i) + CHECK(std::isinf(r.imag()), passed, i) + CHECK(std::signbit(r.imag()) == + std::signbit(std::sin(testcases[i].imag())), + passed, i) // If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are // unspecified) } else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 &&