From 824bd02f51a6c6a58442bba4ee2ea83c21142e5d Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 14:55:29 -0300 Subject: [PATCH 1/8] Add BitMask utils Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 163 +++++++++++++++++++++++++++++++++ stdlib/test/bit/test_mask.mojo | 108 ++++++++++++++++++++++ 2 files changed, 271 insertions(+) create mode 100644 stdlib/src/bit/mask.mojo create mode 100644 stdlib/test/bit/test_mask.mojo diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo new file mode 100644 index 0000000000..5cbc1a4524 --- /dev/null +++ b/stdlib/src/bit/mask.mojo @@ -0,0 +1,163 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2024, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""Provides functions for bit manipulation. + +You can import these APIs from the `bit` package. For example: + +```mojo +from bit.utils import count_leading_zeros +``` +""" + +from os import abort +from sys.info import bitwidthof + + +struct BitMask: + """Utils for building bitmasks.""" + + alias EQ = 0 + """Value for `==`.""" + alias NE = 1 + """Value for `!=`.""" + alias GT = 2 + """Value for `>`.""" + alias GE = 3 + """Value for `>=`.""" + alias LT = 4 + """Value for `<`.""" + alias LE = 5 + """Value for `<=`.""" + + @always_inline + @staticmethod + fn is_negative(value: Int) -> Int: + """Get a bitmask of whether the value is negative. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is negative, filled with `0` + otherwise. + """ + return int(Self.is_negative(Scalar[DType.index](value))) + + @always_inline + @staticmethod + fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value): + """Get a bitmask of whether the value is negative. + + Parameters: + D: The DType. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is negative, filled with `0` + otherwise. + """ + constrained[ + D.is_integral() and D.is_signed(), + "This function is for signed integral types.", + ]() + return value >> (bitwidthof[D]() - 1) + + @always_inline + @staticmethod + fn is_true[ + D: DType, size: Int = 1 + ](value: SIMD[DType.bool, size]) -> SIMD[D, size]: + """Get a bitmask of whether the value is `True`. + + Parameters: + D: The DType. + size: The size of the SIMD vector. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is `True`, filled with `0` + otherwise. + """ + return Self.is_false[D](~value) + + @always_inline + @staticmethod + fn is_false[ + D: DType, size: Int = 1 + ](value: SIMD[DType.bool, size]) -> SIMD[D, size]: + """Get a bitmask of whether the value is `False`. + + Parameters: + D: The DType. + size: The size of the SIMD vector. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is `False`, filled with `0` + otherwise. + """ + return (value.cast[DType.int8]() - 1).cast[D]() + + @always_inline + @staticmethod + fn compare[ + D: DType, //, comp: Int + ](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs): + """Get a bitmask of the comparison between the two values. + + Args: + lhs: The value to check. + rhs: The value to check. + + Returns: + A bitmask filled with `1` if the comparison is true, filled with `0` + otherwise. + """ + + @parameter + if comp == Self.EQ: + return Self.is_true[D](lhs == rhs) + elif comp == Self.NE: + return Self.is_true[D](lhs != rhs) + elif comp == Self.GT: + return Self.is_true[D](lhs > rhs) + elif comp == Self.GE: + return Self.is_true[D](lhs >= rhs) + elif comp == Self.LT: + return Self.is_true[D](lhs < rhs) + elif comp == Self.LE: + return Self.is_true[D](lhs <= rhs) + else: + constrained[False, "comparison operator value not found"]() + return abort[__type_of(lhs)]() + + @staticmethod + fn compare[D: DType, //, comp: Int](lhs: Int, rhs: Int) -> Int: + """Get a bitmask of the comparison between the two values. + + Args: + lhs: The value to check. + rhs: The value to check. + + Returns: + A bitmask filled with `1` if the comparison is true, filled with `0` + otherwise. + """ + alias S = Scalar[DType.index] + return int(Self.compare[comp=comp](S(lhs), S(rhs))) diff --git a/stdlib/test/bit/test_mask.mojo b/stdlib/test/bit/test_mask.mojo new file mode 100644 index 0000000000..8b198d72b0 --- /dev/null +++ b/stdlib/test/bit/test_mask.mojo @@ -0,0 +1,108 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2024, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +# RUN: %bare-mojo %s + +from testing import assert_equal +from bit.mask import BitMask +from sys.info import bitwidthof + + +def test_is_negative(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes.get[i, DType]() + var last_value = 2 ** (bitwidthof[D]() - 1) - 1 + var values = List(1, 2, last_value - 1, last_value) + + @parameter + for j in range(len(widths)): + alias S = SIMD[D, widths.get[j, Int]()] + + for k in values: + assert_equal(S(-1), BitMask.is_negative(S(-k[]))) + assert_equal(S(0), BitMask.is_negative(S(k[]))) + + +def test_is_true(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + DType.uint8, + DType.uint16, + DType.uint32, + DType.uint64, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes.get[i, DType]() + + @parameter + for j in range(len(widths)): + alias w = widths.get[j, Int]() + alias B = SIMD[DType.bool, w] + assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True))) + assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False))) + + +def test_compare(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes.get[i, DType]() + var last_value = 2 ** (bitwidthof[D]() - 1) - 1 + var values = List(1, 2, last_value - 1, last_value) + + @parameter + for j in range(len(widths)): + alias S = SIMD[D, widths.get[j, Int]()] + + for k in values: + var s_k = S(k[]) + var s_k_1 = S(k[] - 1) + assert_equal(S(-1), BitMask.compare[BitMask.EQ](s_k, s_k)) + assert_equal(S(-1), BitMask.compare[BitMask.EQ](-s_k, -s_k)) + assert_equal(S(-1), BitMask.compare[BitMask.NE](s_k, s_k_1)) + assert_equal(S(-1), BitMask.compare[BitMask.NE](-s_k, s_k_1)) + assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k, s_k_1)) + assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k_1, -s_k)) + assert_equal(S(-1), BitMask.compare[BitMask.GE](-s_k, -s_k)) + assert_equal(S(-1), BitMask.compare[BitMask.LT](-s_k, s_k_1)) + assert_equal(S(-1), BitMask.compare[BitMask.LE](-s_k, -s_k)) + + +def main(): + test_is_negative() + test_is_true() + test_compare() From cf816bec5f586d0e1141cdc1e0f00824052d6322 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 14:59:58 -0300 Subject: [PATCH 2/8] fix docstring Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo index 5cbc1a4524..58a1f53828 100644 --- a/stdlib/src/bit/mask.mojo +++ b/stdlib/src/bit/mask.mojo @@ -121,6 +121,10 @@ struct BitMask: ](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs): """Get a bitmask of the comparison between the two values. + Parameters: + D: The DType. + comp: The comparison operator, e.g. `BitMask.EQ`. + Args: lhs: The value to check. rhs: The value to check. From 391ffb45891b590fe6630c3fa90df183e3a27be1 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 15:04:43 -0300 Subject: [PATCH 3/8] fix docstring Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo index 58a1f53828..7ebaec48f2 100644 --- a/stdlib/src/bit/mask.mojo +++ b/stdlib/src/bit/mask.mojo @@ -152,9 +152,12 @@ struct BitMask: return abort[__type_of(lhs)]() @staticmethod - fn compare[D: DType, //, comp: Int](lhs: Int, rhs: Int) -> Int: + fn compare[comp: Int](lhs: Int, rhs: Int) -> Int: """Get a bitmask of the comparison between the two values. + Parameters: + comp: The comparison operator, e.g. `BitMask.EQ`. + Args: lhs: The value to check. rhs: The value to check. From d69f0d9137176b0787bcd4f231bba2fcd47141a7 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 15:44:18 -0300 Subject: [PATCH 4/8] fix docstrings Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo index 7ebaec48f2..f14bfc3d0d 100644 --- a/stdlib/src/bit/mask.mojo +++ b/stdlib/src/bit/mask.mojo @@ -10,12 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ===----------------------------------------------------------------------=== # -"""Provides functions for bit manipulation. +"""Provides functions for bit masks. You can import these APIs from the `bit` package. For example: ```mojo -from bit.utils import count_leading_zeros +from bit.mask import BitMask ``` """ From 654e243c1daf662ee8a8a4668811cf87f278e290 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 19:50:17 -0300 Subject: [PATCH 5/8] fix remove use of subtraction Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo index f14bfc3d0d..39ce700ba6 100644 --- a/stdlib/src/bit/mask.mojo +++ b/stdlib/src/bit/mask.mojo @@ -92,7 +92,7 @@ struct BitMask: A bitmask filled with `1` if the value is `True`, filled with `0` otherwise. """ - return Self.is_false[D](~value) + return (-(value.cast[DType.int8]())).cast[D]() @always_inline @staticmethod @@ -112,7 +112,7 @@ struct BitMask: A bitmask filled with `1` if the value is `False`, filled with `0` otherwise. """ - return (value.cast[DType.int8]() - 1).cast[D]() + return Self.is_true[D](~value) @always_inline @staticmethod From f03b7cefdaa8f4a5c0f535cb3f9d96f34b49e1cf Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Sat, 1 Feb 2025 11:42:21 -0300 Subject: [PATCH 6/8] remove tuple get usage Signed-off-by: martinvuyk --- stdlib/src/bit/mask.mojo | 4 ++-- stdlib/test/bit/test_mask.mojo | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/stdlib/src/bit/mask.mojo b/stdlib/src/bit/mask.mojo index 39ce700ba6..3c72a98a52 100644 --- a/stdlib/src/bit/mask.mojo +++ b/stdlib/src/bit/mask.mojo @@ -51,7 +51,7 @@ struct BitMask: A bitmask filled with `1` if the value is negative, filled with `0` otherwise. """ - return int(Self.is_negative(Scalar[DType.index](value))) + return Int(Self.is_negative(Scalar[DType.index](value))) @always_inline @staticmethod @@ -167,4 +167,4 @@ struct BitMask: otherwise. """ alias S = Scalar[DType.index] - return int(Self.compare[comp=comp](S(lhs), S(rhs))) + return Int(Self.compare[comp=comp](S(lhs), S(rhs))) diff --git a/stdlib/test/bit/test_mask.mojo b/stdlib/test/bit/test_mask.mojo index 8b198d72b0..9f74cf3947 100644 --- a/stdlib/test/bit/test_mask.mojo +++ b/stdlib/test/bit/test_mask.mojo @@ -29,13 +29,13 @@ def test_is_negative(): @parameter for i in range(len(dtypes)): - alias D = dtypes.get[i, DType]() + alias D = dtypes[i] var last_value = 2 ** (bitwidthof[D]() - 1) - 1 var values = List(1, 2, last_value - 1, last_value) @parameter for j in range(len(widths)): - alias S = SIMD[D, widths.get[j, Int]()] + alias S = SIMD[D, widths[j]] for k in values: assert_equal(S(-1), BitMask.is_negative(S(-k[]))) @@ -58,11 +58,11 @@ def test_is_true(): @parameter for i in range(len(dtypes)): - alias D = dtypes.get[i, DType]() + alias D = dtypes[i] @parameter for j in range(len(widths)): - alias w = widths.get[j, Int]() + alias w = widths[j] alias B = SIMD[DType.bool, w] assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True))) assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False))) @@ -80,13 +80,13 @@ def test_compare(): @parameter for i in range(len(dtypes)): - alias D = dtypes.get[i, DType]() + alias D = dtypes[i] var last_value = 2 ** (bitwidthof[D]() - 1) - 1 var values = List(1, 2, last_value - 1, last_value) @parameter for j in range(len(widths)): - alias S = SIMD[D, widths.get[j, Int]()] + alias S = SIMD[D, widths[j]] for k in values: var s_k = S(k[]) From eea09a34bbf027fa795681ff2a6505919b4debd5 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 28 Feb 2025 12:59:49 -0300 Subject: [PATCH 7/8] fix remove comparison to reduce scope of PR Signed-off-by: martinvuyk --- mojo/stdlib/src/bit/mask.mojo | 68 ----------------------------- mojo/stdlib/test/bit/test_mask.mojo | 18 ++++---- 2 files changed, 9 insertions(+), 77 deletions(-) diff --git a/mojo/stdlib/src/bit/mask.mojo b/mojo/stdlib/src/bit/mask.mojo index 934e3543a4..809ee61d8d 100644 --- a/mojo/stdlib/src/bit/mask.mojo +++ b/mojo/stdlib/src/bit/mask.mojo @@ -26,19 +26,6 @@ from sys.info import bitwidthof struct BitMask: """Utils for building bitmasks.""" - alias EQ = 0 - """Value for `==`.""" - alias NE = 1 - """Value for `!=`.""" - alias GT = 2 - """Value for `>`.""" - alias GE = 3 - """Value for `>=`.""" - alias LT = 4 - """Value for `<`.""" - alias LE = 5 - """Value for `<=`.""" - @always_inline @staticmethod fn is_negative(value: Int) -> Int: @@ -113,58 +100,3 @@ struct BitMask: otherwise. """ return Self.is_true[D](~value) - - @always_inline - @staticmethod - fn compare[ - D: DType, //, comp: Int - ](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs): - """Get a bitmask of the comparison between the two values. - - Parameters: - D: The DType. - comp: The comparison operator, e.g. `BitMask.EQ`. - - Args: - lhs: The value to check. - rhs: The value to check. - - Returns: - A bitmask filled with `1` if the comparison is true, filled with `0` - otherwise. - """ - - @parameter - if comp == Self.EQ: - return Self.is_true[D](lhs == rhs) - elif comp == Self.NE: - return Self.is_true[D](lhs != rhs) - elif comp == Self.GT: - return Self.is_true[D](lhs > rhs) - elif comp == Self.GE: - return Self.is_true[D](lhs >= rhs) - elif comp == Self.LT: - return Self.is_true[D](lhs < rhs) - elif comp == Self.LE: - return Self.is_true[D](lhs <= rhs) - else: - constrained[False, "comparison operator value not found"]() - return abort[__type_of(lhs)]() - - @staticmethod - fn compare[comp: Int](lhs: Int, rhs: Int) -> Int: - """Get a bitmask of the comparison between the two values. - - Parameters: - comp: The comparison operator, e.g. `BitMask.EQ`. - - Args: - lhs: The value to check. - rhs: The value to check. - - Returns: - A bitmask filled with `1` if the comparison is true, filled with `0` - otherwise. - """ - alias S = Scalar[DType.index] - return Int(Self.compare[comp=comp](S(lhs), S(rhs))) diff --git a/mojo/stdlib/test/bit/test_mask.mojo b/mojo/stdlib/test/bit/test_mask.mojo index 55daafee68..25e2cc2cdc 100644 --- a/mojo/stdlib/test/bit/test_mask.mojo +++ b/mojo/stdlib/test/bit/test_mask.mojo @@ -91,15 +91,15 @@ def test_compare(): for k in values: var s_k = S(k[]) var s_k_1 = S(k[] - 1) - assert_equal(S(-1), BitMask.compare[BitMask.EQ](s_k, s_k)) - assert_equal(S(-1), BitMask.compare[BitMask.EQ](-s_k, -s_k)) - assert_equal(S(-1), BitMask.compare[BitMask.NE](s_k, s_k_1)) - assert_equal(S(-1), BitMask.compare[BitMask.NE](-s_k, s_k_1)) - assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k, s_k_1)) - assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k_1, -s_k)) - assert_equal(S(-1), BitMask.compare[BitMask.GE](-s_k, -s_k)) - assert_equal(S(-1), BitMask.compare[BitMask.LT](-s_k, s_k_1)) - assert_equal(S(-1), BitMask.compare[BitMask.LE](-s_k, -s_k)) + assert_equal(S(-1), BitMask.is_true[D](s_k == s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k == -s_k)) + assert_equal(S(-1), BitMask.is_true[D](s_k != s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](-s_k != s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](s_k > s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](s_k_1 > -s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k >= -s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k < s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](-s_k <= -s_k)) def main(): From 8bf10ead75a19a2b944001b53a89e35254e6a118 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 28 Feb 2025 13:41:36 -0300 Subject: [PATCH 8/8] retry github CI Signed-off-by: martinvuyk