Skip to content

Commit

Permalink
Disable vectorized bitwise_invert for boolean inputs (#1681)
Browse files Browse the repository at this point in the history
* Removes `sycl::vec` overload in `BitwiseInvertFunctor`

This overload caused sufficiently large boolean arrays to produce unexpected results when the inverted output was cast to another type.

* Adds a test for fixed bitwise_invert behavior

* Re-enable vectorized `bitwise_invert` for integer types
  • Loading branch information
ndgrigorian authored May 16, 2024
1 parent c994666 commit ba09dd8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ template <typename argT, typename resT> struct BitwiseInvertFunctor

// Compile-time trait aliases of BitwiseInvertFunctor. This is a flattened diff
// view, so superseded and replacement lines may appear side by side.
using is_constant = typename std::false_type;
// constexpr resT constant_value = resT{};
// NOTE(review): presumably the pre-change alias, since a second `supports_vec`
// alias follows immediately -- confirm against the repository.
using supports_vec = typename std::true_type;
// True for every argT except bool (negation of is_same<argT, bool>).
using supports_vec = typename std::negation<std::is_same<argT, bool>>;
using supports_sg_loadstore = typename std::true_type;
// NOTE(review): lone semicolon; likely a stray line deleted by this commit.
;

resT operator()(const argT &in) const
{
Expand All @@ -75,16 +74,7 @@ template <typename argT, typename resT> struct BitwiseInvertFunctor
// Vector overload: takes a sycl::vec<argT, vec_sz> and returns a
// sycl::vec<resT, vec_sz>.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in) const
{
// NOTE(review): in this flattened diff view the `if constexpr`/`else` block
// below appears to be the code the commit removes, with the final
// `return ~in;` as its replacement -- confirm against the repository.
if constexpr (std::is_same_v<argT, bool>) {
// Boolean input: apply `!` to the vector.
auto res_vec = !in;

// Element type of the result of `!in`, passed to vec_cast as the source type.
using deducedT = typename std::remove_cv_t<
std::remove_reference_t<decltype(res_vec)>>::element_type;
return vec_cast<resT, deducedT, vec_sz>(res_vec);
}
else {
// Non-boolean input: apply `~` to the vector.
return ~in;
}
return ~in;
}
};

Expand Down
10 changes: 10 additions & 0 deletions dpctl/tests/elementwise/test_bitwise_invert.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,13 @@ def test_bitwise_invert_order():
ar1 = dpt.zeros((40, 40), dtype="i4", order="C")[:20, ::-2].mT
r4 = dpt.bitwise_invert(ar1, order="K")
assert r4.strides == (-1, 20)


def test_bitwise_invert_large_boolean():
    """Invert a 32x32 boolean array and check the int32 cast stays within [0, 1]."""
    get_queue_or_skip()

    # Strictly lower-triangular boolean mask (k=-1 excludes the main diagonal).
    all_true = dpt.ones((32, 32), dtype="?")
    mask = dpt.tril(all_true, k=-1)

    inverted = dpt.bitwise_invert(mask)
    res = dpt.astype(inverted, "i4")

    # Every element of the cast result must be 0 or 1.
    assert dpt.all(res >= 0)
    assert dpt.all(res <= 1)

0 comments on commit ba09dd8

Please sign in to comment.