diff --git a/mojo/stdlib/src/bit/mask.mojo b/mojo/stdlib/src/bit/mask.mojo new file mode 100644 index 0000000000..809ee61d8d --- /dev/null +++ b/mojo/stdlib/src/bit/mask.mojo @@ -0,0 +1,102 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""Provides functions for bit masks. + +You can import these APIs from the `bit` package. For example: + +```mojo +from bit.mask import BitMask +``` +""" + +from os import abort +from sys.info import bitwidthof + + +struct BitMask: + """Utils for building bitmasks.""" + + @always_inline + @staticmethod + fn is_negative(value: Int) -> Int: + """Get a bitmask of whether the value is negative. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is negative, filled with `0` + otherwise. + """ + return Int(Self.is_negative(Scalar[DType.index](value))) + + @always_inline + @staticmethod + fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value): + """Get a bitmask of whether the value is negative. + + Parameters: + D: The DType. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is negative, filled with `0` + otherwise. + """ + constrained[ + D.is_integral() and D.is_signed(), + "This function is for signed integral types.", + ]() + return value >> (bitwidthof[D]() - 1) + + @always_inline + @staticmethod + fn is_true[ + D: DType, size: Int = 1 + ](value: SIMD[DType.bool, size]) -> SIMD[D, size]: + """Get a bitmask of whether the value is `True`. + + Parameters: + D: The DType. + size: The size of the SIMD vector. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is `True`, filled with `0` + otherwise. + """ + return (-(value.cast[DType.int8]())).cast[D]() + + @always_inline + @staticmethod + fn is_false[ + D: DType, size: Int = 1 + ](value: SIMD[DType.bool, size]) -> SIMD[D, size]: + """Get a bitmask of whether the value is `False`. + + Parameters: + D: The DType. + size: The size of the SIMD vector. + + Args: + value: The value to check. + + Returns: + A bitmask filled with `1` if the value is `False`, filled with `0` + otherwise. + """ + return Self.is_true[D](~value) diff --git a/mojo/stdlib/test/bit/test_mask.mojo b/mojo/stdlib/test/bit/test_mask.mojo new file mode 100644 index 0000000000..25e2cc2cdc --- /dev/null +++ b/mojo/stdlib/test/bit/test_mask.mojo @@ -0,0 +1,108 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +# RUN: %bare-mojo %s + +from testing import assert_equal +from bit.mask import BitMask +from sys.info import bitwidthof + + +def test_is_negative(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes[i] + var last_value = 2 ** (bitwidthof[D]() - 1) - 1 + var values = List(1, 2, last_value - 1, last_value) + + @parameter + for j in range(len(widths)): + alias S = SIMD[D, widths[j]] + + for k in values: + assert_equal(S(-1), BitMask.is_negative(S(-k[]))) + assert_equal(S(0), BitMask.is_negative(S(k[]))) + + +def test_is_true(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + DType.uint8, + DType.uint16, + DType.uint32, + DType.uint64, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes[i] + + @parameter + for j in range(len(widths)): + alias w = widths[j] + alias B = SIMD[DType.bool, w] + assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True))) + assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False))) + + +def test_compare(): + alias dtypes = ( + DType.int8, + DType.int16, + DType.int32, + DType.int64, + DType.index, + ) + alias widths = (1, 2, 4, 8) + + @parameter + for i in range(len(dtypes)): + alias D = dtypes[i] + var last_value = 2 ** (bitwidthof[D]() - 1) - 1 + var values = List(1, 2, last_value - 1, last_value) + + @parameter + for j in range(len(widths)): + alias S = SIMD[D, widths[j]] + + for k in values: + var s_k = S(k[]) + var s_k_1 = S(k[] - 1) + assert_equal(S(-1), BitMask.is_true[D](s_k == s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k == -s_k)) + assert_equal(S(-1), BitMask.is_true[D](s_k != s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](-s_k != s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](s_k > s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](s_k_1 > -s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k >= -s_k)) + assert_equal(S(-1), BitMask.is_true[D](-s_k < s_k_1)) + assert_equal(S(-1), BitMask.is_true[D](-s_k <= -s_k)) + + +def main(): + test_is_negative() + test_is_true() + test_compare()