Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add BitMask utils #3886

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
102 changes: 102 additions & 0 deletions mojo/stdlib/src/bit/mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Provides functions for bit masks.

You can import these APIs from the `bit` package. For example:

```mojo
from bit.mask import BitMask
```
"""

from os import abort
from sys.info import bitwidthof


struct BitMask:
"""Utils for building bitmasks."""

@always_inline
@staticmethod
fn is_negative(value: Int) -> Int:
"""Get a bitmask of whether the value is negative.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
return Int(Self.is_negative(Scalar[DType.index](value)))

@always_inline
@staticmethod
fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value):
"""Get a bitmask of whether the value is negative.

Parameters:
D: The DType.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
constrained[
D.is_integral() and D.is_signed(),
"This function is for signed integral types.",
]()
return value >> (bitwidthof[D]() - 1)

@always_inline
@staticmethod
fn is_true[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `True`.

Parameters:
D: The DType.
size: The size of the SIMD vector.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is `True`, filled with `0`
otherwise.
"""
return (-(value.cast[DType.int8]())).cast[D]()

@always_inline
@staticmethod
fn is_false[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `False`.

Parameters:
D: The DType.
size: The size of the SIMD vector.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is `False`, filled with `0`
otherwise.
"""
return Self.is_true[D](~value)
108 changes: 108 additions & 0 deletions mojo/stdlib/test/bit/test_mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %bare-mojo %s

from testing import assert_equal
from bit.mask import BitMask
from sys.info import bitwidthof


def test_is_negative():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes[i]
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths[j]]

for k in values:
assert_equal(S(-1), BitMask.is_negative(S(-k[])))
assert_equal(S(0), BitMask.is_negative(S(k[])))


def test_is_true():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
DType.uint8,
DType.uint16,
DType.uint32,
DType.uint64,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes[i]

@parameter
for j in range(len(widths)):
alias w = widths[j]
alias B = SIMD[DType.bool, w]
assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True)))
assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False)))


def test_compare():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes[i]
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths[j]]

for k in values:
var s_k = S(k[])
var s_k_1 = S(k[] - 1)
assert_equal(S(-1), BitMask.is_true[D](s_k == s_k))
assert_equal(S(-1), BitMask.is_true[D](-s_k == -s_k))
assert_equal(S(-1), BitMask.is_true[D](s_k != s_k_1))
assert_equal(S(-1), BitMask.is_true[D](-s_k != s_k_1))
assert_equal(S(-1), BitMask.is_true[D](s_k > s_k_1))
assert_equal(S(-1), BitMask.is_true[D](s_k_1 > -s_k))
assert_equal(S(-1), BitMask.is_true[D](-s_k >= -s_k))
assert_equal(S(-1), BitMask.is_true[D](-s_k < s_k_1))
assert_equal(S(-1), BitMask.is_true[D](-s_k <= -s_k))


def main():
test_is_negative()
test_is_true()
test_compare()