Skip to content

Commit

Permalink
[SYCL][COMPAT] Add extend_sh* to syclcompat (#13727)
Browse files Browse the repository at this point in the history
Adds `extend_sh*` math operators to SYCLcompat.

This PR includes testing for all the different permutations of the
functions.
  • Loading branch information
AidanBeltonS authored May 10, 2024
1 parent fea7e77 commit cbed0d7
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 0 deletions.
196 changes: 196 additions & 0 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,22 @@ struct sub_sat {
}
};

namespace detail {
struct shift_left {
template <typename T>
auto operator()(const T x, const uint32_t offset) const {
return x << offset;
}
};

struct shift_right {
template <typename T>
auto operator()(const T x, const uint32_t offset) const {
return x >> offset;
}
};
} // namespace detail

/// Compute vectorized binary operation value for two values, with each value
/// treated as a vector type \p VecT.
/// \tparam [in] VecT The type of the vector
Expand Down Expand Up @@ -1002,4 +1018,184 @@ inline constexpr RetT extend_max_sat(AT a, BT b, CT c,
return detail::extend_binary<RetT, true>(a, b, c, maximum(), second_op);
}

/// Extend \p a and \p b to 33 bit and return a << clamp(b, 0, 32).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns a << clamp(b, 0, 32)
template <typename RetT, typename T>
inline constexpr RetT extend_shl_clamp(T a, uint32_t b) {
return detail::extend_binary<RetT, false>(a, sycl::clamp(b, 0u, 32u),
detail::shift_left());
}

/// Extend Inputs to 33 bit, and return second_op(a << clamp(b, 0, 32), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(a << clamp(b, 0, 32), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shl_clamp(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, false>(a, sycl::clamp(b, 0u, 32u), c,
detail::shift_left(), second_op);
}

/// Extend \p a and \p b to 33 bit and return sat(a << clamp(b, 0, 32)).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns sat(a << clamp(b, 0, 32))
template <typename RetT, typename T>
inline constexpr RetT extend_shl_sat_clamp(T a, uint32_t b) {
return detail::extend_binary<RetT, true>(a, sycl::clamp(b, 0u, 32u),
detail::shift_left());
}

/// Extend Inputs to 33 bit, and return second_op(sat(a << clamp(b, 0, 32)), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(sat(a << clamp(b, 0, 32)), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shl_sat_clamp(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, true>(a, sycl::clamp(b, 0u, 32u), c,
detail::shift_left(), second_op);
}

/// Extend \p a and \p b to 33 bit and return a << (b & 0x1F).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns a << (b & 0x1F)
template <typename RetT, typename T>
inline constexpr RetT extend_shl_wrap(T a, uint32_t b) {
return detail::extend_binary<RetT, false>(a, b & 0x1F, detail::shift_left());
}

/// Extend Inputs to 33 bit, and return second_op(a << (b & 0x1F), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(a << (b & 0x1F), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shl_wrap(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, false>(a, b & 0x1F, c,
detail::shift_left(), second_op);
}

/// Extend \p a and \p b to 33 bit and return sat(a << (b & 0x1F)).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns sat(a << (b & 0x1F))
template <typename RetT, typename T>
inline constexpr RetT extend_shl_sat_wrap(T a, uint32_t b) {
return detail::extend_binary<RetT, true>(a, b & 0x1F, detail::shift_left());
}

/// Extend Inputs to 33 bit, and return second_op(sat(a << (b & 0x1F)), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(sat(a << (b & 0x1F)), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shl_sat_wrap(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, true>(a, b & 0x1F, c, detail::shift_left(),
second_op);
}

/// Extend \p a and \p b to 33 bit and return a >> clamp(b, 0, 32).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns a >> clamp(b, 0, 32)
template <typename RetT, typename T>
inline constexpr RetT extend_shr_clamp(T a, uint32_t b) {
return detail::extend_binary<RetT, false>(a, sycl::clamp(b, 0u, 32u),
detail::shift_right());
}

/// Extend Inputs to 33 bit, and return second_op(a >> clamp(b, 0, 32), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(a >> clamp(b, 0, 32), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shr_clamp(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, false>(a, sycl::clamp(b, 0u, 32u), c,
detail::shift_right(), second_op);
}

/// Extend \p a and \p b to 33 bit and return sat(a >> clamp(b, 0, 32)).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns sat(a >> clamp(b, 0, 32))
template <typename RetT, typename T>
inline constexpr RetT extend_shr_sat_clamp(T a, uint32_t b) {
return detail::extend_binary<RetT, true>(a, sycl::clamp(b, 0u, 32u),
detail::shift_right());
}

/// Extend Inputs to 33 bit, and return second_op(sat(a >> clamp(b, 0, 32)), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(sat(a >> clamp(b, 0, 32)), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shr_sat_clamp(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, true>(a, sycl::clamp(b, 0u, 32u), c,
detail::shift_right(), second_op);
}

/// Extend \p a and \p b to 33 bit and return a >> (b & 0x1F).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns a >> (b & 0x1F)
template <typename RetT, typename T>
inline constexpr RetT extend_shr_wrap(T a, uint32_t b) {
return detail::extend_binary<RetT, false>(a, b & 0x1F, detail::shift_right());
}

/// Extend Inputs to 33 bit, and return second_op(a >> (b & 0x1F), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(a >> (b & 0x1F), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shr_wrap(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, false>(a, b & 0x1F, c,
detail::shift_right(), second_op);
}

/// Extend \p a and \p b to 33 bit and return sat(a >> (b & 0x1F)).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \returns sat(a >> (b & 0x1F))
template <typename RetT, typename T>
inline constexpr RetT extend_shr_sat_wrap(T a, uint32_t b) {
return detail::extend_binary<RetT, true>(a, b & 0x1F, detail::shift_right());
}

/// Extend Inputs to 33 bit, and return second_op(sat(a >> (b & 0x1F)), c).
/// \param [in] a The source value
/// \param [in] b The offset to shift
/// \param [in] c The value to merge
/// \param [in] second_op The operation to do with the third value
/// \returns second_op(sat(a >> (b & 0x1F)), c)
template <typename RetT, typename T, typename BinaryOperation>
inline constexpr RetT extend_shr_sat_wrap(T a, uint32_t b, uint32_t c,
BinaryOperation second_op) {
return detail::extend_binary<RetT, true>(a, b & 0x1F, c,
detail::shift_right(), second_op);
}

} // namespace syclcompat
107 changes: 107 additions & 0 deletions sycl/test-e2e/syclcompat/math/math_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,77 @@ std::pair<const char *, int> vmax() {
return {nullptr, 0};
}

template <typename Tp> struct scale {
Tp operator()(Tp val, Tp scaler) { return val * scaler; }
};

template <typename Tp> struct noop {
Tp operator()(Tp val, Tp /*scaler*/) { return val; }
};

std::pair<const char *, int> shl_clamp() {
CHECK(syclcompat::extend_shl_clamp<int32_t>(3, 4), 48);
CHECK(syclcompat::extend_shl_clamp<int32_t>(6, 33), 0);
CHECK(syclcompat::extend_shl_clamp<int32_t>(3, 4, 4, scale<int32_t>()), 192);
CHECK(syclcompat::extend_shl_clamp<int32_t>(3, 4, 4, noop<int32_t>()), 48);
CHECK(syclcompat::extend_shl_sat_clamp<int8_t>(9, 5), 127);
CHECK(syclcompat::extend_shl_sat_clamp<int8_t>(-9, 5), -128);
CHECK(syclcompat::extend_shl_sat_clamp<int8_t>(9, 5, -1, scale<int8_t>()),
-127);
CHECK(syclcompat::extend_shl_sat_clamp<int8_t>(9, 5, -1, noop<int8_t>()),
127);

return {nullptr, 0};
}

std::pair<const char *, int> shl_wrap() {
CHECK(syclcompat::extend_shl_wrap<int32_t>(3, 4), 48);
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 32), 6);
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 33), 12);
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 64), 6);
CHECK(syclcompat::extend_shl_wrap<int32_t>(3, 4, 4, scale<int32_t>()), 192);
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 32, 4, noop<int32_t>()), 6);
CHECK(syclcompat::extend_shl_sat_wrap<int8_t>(9, 5), 127);
CHECK(syclcompat::extend_shl_sat_wrap<int8_t>(-9, 5), -128);
CHECK(syclcompat::extend_shl_sat_wrap<int8_t>(9, 5, -1, scale<int8_t>()),
-127);
CHECK(syclcompat::extend_shl_sat_wrap<int8_t>(9, 5, -1, noop<int8_t>()), 127);

return {nullptr, 0};
}

std::pair<const char *, int> shr_clamp() {
CHECK(syclcompat::extend_shr_clamp<int32_t>(128, 5), 4);
CHECK(syclcompat::extend_shr_clamp<int32_t>(INT32MAX, 33), 0);
CHECK(syclcompat::extend_shr_clamp<int32_t>(128, 5, 4, scale<int32_t>()), 16);
CHECK(syclcompat::extend_shr_clamp<int32_t>(128, 5, 4, noop<int32_t>()), 4);
CHECK(syclcompat::extend_shr_sat_clamp<int8_t>(512, 1), 127);
CHECK(syclcompat::extend_shr_sat_clamp<int8_t>(-512, 1), -128);
CHECK(syclcompat::extend_shr_sat_clamp<int8_t>(512, 1, -1, scale<int8_t>()),
-127);
CHECK(syclcompat::extend_shr_sat_clamp<int8_t>(512, 1, -1, noop<int8_t>()),
127);

return {nullptr, 0};
}

std::pair<const char *, int> shr_wrap() {
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 5), 4);
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 32), 128);
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 33), 64);
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 64), 128);
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 5, 4, scale<int32_t>()), 16);
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 5, 4, noop<int32_t>()), 4);
CHECK(syclcompat::extend_shr_sat_wrap<int8_t>(512, 1), 127);
CHECK(syclcompat::extend_shr_sat_wrap<int8_t>(-512, 1), -128);
CHECK(syclcompat::extend_shr_sat_wrap<int8_t>(512, 1, -1, scale<int8_t>()),
-127);
CHECK(syclcompat::extend_shr_sat_wrap<int8_t>(512, 1, -1, noop<int8_t>()),
127);

return {nullptr, 0};
}

void test(const sycl::stream &s, int *ec) {
{
auto res = vadd();
Expand Down Expand Up @@ -165,6 +236,42 @@ void test(const sycl::stream &s, int *ec) {
}
s << "vmax check passed!\n";
}
{
auto res = shl_clamp();
if (res.first) {
s << res.first << " = " << res.second << " check failed!\n";
*ec = 6;
return;
}
s << "shl_clamp check passed!\n";
}
{
auto res = shl_wrap();
if (res.first) {
s << res.first << " = " << res.second << " check failed!\n";
*ec = 7;
return;
}
s << "shl_wrap check passed!\n";
}
{
auto res = shr_clamp();
if (res.first) {
s << res.first << " = " << res.second << " check failed!\n";
*ec = 8;
return;
}
s << "shr_clamp check passed!\n";
}
{
auto res = shr_wrap();
if (res.first) {
s << res.first << " = " << res.second << " check failed!\n";
*ec = 9;
return;
}
s << "shr_wrap check passed!\n";
}
*ec = 0;
}

Expand Down

0 comments on commit cbed0d7

Please sign in to comment.