Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add Span.count() #3792

Open
wants to merge 4 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ what we publish.
([PR #3160](https://github.com/modularml/mojo/pull/3160) by
[@bgreni](https://github.com/bgreni))

- `Span` now implements a generic `.count()` method which can be passed a
function that returns a boolean SIMD vector. The method counts how many
times that function returns `True`, evaluating it in a vectorized manner. This
works for any `Span[Scalar[D]]`, e.g. `Span[Byte]`. It also implements the
Python-like `.count(sub)`, which counts the non-overlapping occurrences of a
subsequence and works for any scalar value sequence.
([PR #3792](https://github.com/modularml/mojo/pull/3792) by [@martinvuyk](https://github.com/martinvuyk))

- `StringRef` now implements `split()` which can be used to split a
`StringRef` into a `List[StringRef]` by a delimiter.
([PR #2705](https://github.com/modularml/mojo/pull/2705) by [@fknfilewalker](https://github.com/fknfilewalker))
Expand Down
127 changes: 123 additions & 4 deletions stdlib/src/utils/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from utils import Span
from collections import InlineArray
from memory import Pointer, UnsafePointer
from builtin.builtin_list import _lit_mut_cast
from sys.info import simdwidthof, sizeof


trait AsBytes:
Expand Down Expand Up @@ -345,8 +346,7 @@ struct Span[
return not self == rhs

fn fill[origin: MutableOrigin, //](self: Span[T, origin], value: T):
"""
Fill the memory that a span references with a given value.
"""Fill the memory that a span references with a given value.

Parameters:
origin: The inferred mutable origin of the data within the Span.
Expand All @@ -358,12 +358,131 @@ struct Span[
element[] = value

fn get_immutable(self) -> Span[T, _lit_mut_cast[origin, False].result]:
    """Return an immutable version of this span.

    Returns:
        A span covering the same elements, but without mutability.
    """
    return Span[T, _lit_mut_cast[origin, False].result](
        ptr=self._data, length=self._len
    )

fn count[D: DType, //](self: Span[Scalar[D]], sub: Span[Scalar[D]]) -> UInt:
    """Return the number of non-overlapping occurrences of subsequence.

    Elements are compared one scalar at a time, so a match can only start on
    an element boundary regardless of how many bytes each scalar occupies.

    Parameters:
        D: The DType.

    Args:
        sub: The subsequence.

    Returns:
        The number of non-overlapping occurrences of subsequence. An empty
        subsequence matches before every element and once at the end, so the
        result is `len(self) + 1` (same convention as Python's `bytes.count`).
    """

    var self_len = len(self)
    var sub_len = len(sub)

    # Empty needle: one match per element boundary, including the end.
    if sub_len == 0:
        return self_len + 1

    # Single element needle: delegate to the vectorized predicate counter.
    if sub_len == 1:
        var s = sub.unsafe_ptr()[0]

        @parameter
        fn equal_fn[w: Int](v: SIMD[D, w]) -> SIMD[DType.bool, w]:
            return v == SIMD[D, w](s)

        return self.count[func=equal_fn]()

    # FIXME(#3548): replace this straightforward scan with Span.find() once it
    # exists. The scan works on elements rather than on the raw bytes because
    # a byte-level search can report matches that straddle two scalars when
    # the scalar type is wider than one byte.
    var haystack = self.unsafe_ptr()
    var needle = sub.unsafe_ptr()
    var amnt = UInt(0)
    var start = 0

    while start + sub_len <= self_len:
        var is_match = True
        for offset in range(sub_len):
            if haystack[start + offset] != needle[offset]:
                is_match = False
                break

        if is_match:
            amnt += 1
            # Skip past the whole match so occurrences never overlap.
            start += sub_len
        else:
            start += 1

    return amnt

fn count[
    D: DType, //, func: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w]
](self: Span[Scalar[D]]) -> UInt:
    """Count the elements for which the function returns `True`.

    The span is walked in SIMD chunks, from the widest candidate width down
    to the narrowest; only widths that do not exceed `simdwidthof[D]()` are
    used. Elements left over after the chunked passes are evaluated one at a
    time.

    Parameters:
        D: The DType.
        func: The function to evaluate.

    Returns:
        The number of times the function returns `True`.
    """

    alias chunk_widths = (256, 128, 64, 32, 16, 8)
    var data = self.unsafe_ptr()
    var total = len(self)
    var hits = UInt(0)
    var offset = 0

    @parameter
    for idx in range(len(chunk_widths)):
        alias lanes = chunk_widths.get[idx, Int]()

        @parameter
        if simdwidthof[D]() >= lanes:
            var full_chunks = (total - offset) // lanes
            for _ in range(full_chunks):
                var chunk = (data + offset).load[width=lanes]()

                # A uint8 accumulator can hold at most 255 set lanes, so the
                # 256-wide case needs the wider uint16 accumulator.
                @parameter
                if lanes < 256:
                    hits += int(func(chunk).cast[DType.uint8]().reduce_add())
                else:
                    hits += int(func(chunk).cast[DType.uint16]().reduce_add())
                offset += lanes

    # Scalar tail: whatever did not fill a whole chunk.
    for pos in range(offset, total):
        hits += int(func(data[pos]))

    return hits

# FIXME(#2535): delete once function effects can be parametrized
fn count[
    D: DType, //,
    func: fn[w: Int] (SIMD[D, w]) capturing -> SIMD[DType.bool, w],
](self: Span[Scalar[D]]) -> UInt:
    """Count the elements for which the capturing function returns `True`.

    The span is walked in SIMD chunks, from the widest candidate width down
    to the narrowest; only widths that do not exceed `simdwidthof[D]()` are
    used. Elements left over after the chunked passes are evaluated one at a
    time.

    Parameters:
        D: The DType.
        func: The function to evaluate.

    Returns:
        The number of times the function returns `True`.
    """

    alias chunk_widths = (256, 128, 64, 32, 16, 8)
    var data = self.unsafe_ptr()
    var total = len(self)
    var hits = UInt(0)
    var offset = 0

    @parameter
    for idx in range(len(chunk_widths)):
        alias lanes = chunk_widths.get[idx, Int]()

        @parameter
        if simdwidthof[D]() >= lanes:
            var full_chunks = (total - offset) // lanes
            for _ in range(full_chunks):
                var chunk = (data + offset).load[width=lanes]()

                # A uint8 accumulator can hold at most 255 set lanes, so the
                # 256-wide case needs the wider uint16 accumulator.
                @parameter
                if lanes < 256:
                    hits += int(func(chunk).cast[DType.uint8]().reduce_add())
                else:
                    hits += int(func(chunk).cast[DType.uint16]().reduce_add())
                offset += lanes

    # Scalar tail: whatever did not fill a whole chunk.
    for pos in range(offset, total):
        hits += int(func(data[pos]))

    return hits
40 changes: 11 additions & 29 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,16 @@ alias StaticString = StringSlice[StaticConstantOrigin]
"""An immutable static string slice."""


fn _count_utf8_continuation_bytes(span: Span[Byte]) -> Int:
alias sizes = (256, 128, 64, 32, 16, 8)
var ptr = span.unsafe_ptr()
var num_bytes = len(span)
var amnt: Int = 0
var processed = 0

@parameter
for i in range(len(sizes)):
alias s = sizes.get[i, Int]()

@parameter
if simdwidthof[DType.uint8]() >= s:
var rest = num_bytes - processed
for _ in range(rest // s):
var vec = (ptr + processed).load[width=s]()
var comp = (vec & 0b1100_0000) == 0b1000_0000
amnt += int(comp.cast[DType.uint8]().reduce_add())
processed += s
@always_inline
fn _is_continuation_byte[
    w: Int
](vec: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Flag every lane whose two most significant bits are `10`.

    Parameters:
        w: The SIMD width of the input vector.

    Args:
        vec: The bytes to test.

    Returns:
        A boolean vector that is `True` in each lane where the byte matches
        the bit pattern `0b10xx_xxxx` (a UTF-8 continuation byte).
    """
    alias top_two_bits = 0b1100_0000
    alias continuation_tag = 0b1000_0000
    var leading_bits = vec & top_two_bits
    return leading_bits == continuation_tag

for i in range(num_bytes - processed):
amnt += int((ptr[processed + i] & 0b1100_0000) == 0b1000_0000)

return amnt
@always_inline
fn _count_utf8_continuation_bytes(span: Span[Byte]) -> Int:
    """Count the bytes in `span` that `_is_continuation_byte` flags.

    Args:
        span: The bytes to inspect.

    Returns:
        The number of bytes for which `_is_continuation_byte` is `True`.
    """
    var continuation_count = span.count[func=_is_continuation_byte]()
    return continuation_count


fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int:
Expand All @@ -76,14 +62,10 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int:
this does not work correctly if given a continuation byte."""

debug_assert(
(b & 0b1100_0000) != 0b1000_0000,
(
"Function `_utf8_first_byte_sequence_length()` does not work"
" correctly if given a continuation byte."
),
not _is_continuation_byte(b),
"Function does not work correctly if given a continuation byte.",
)
var flipped = ~b
return int(count_leading_zeros(flipped) + (flipped >> 7))
return int(count_leading_zeros(~b)) + int(b < 0b1000_0000)


fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):
Expand Down
15 changes: 15 additions & 0 deletions stdlib/test/utils/test_span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,20 @@ def test_ref():
assert_true(s.as_ref() == Pointer.address_of(l.unsafe_ptr()[]))


def test_count():
    """Check `Span.count(sub)` on byte spans against Python-like results."""
    var haystack = String("Hello world").as_bytes()

    # Empty needle: one match per boundary, i.e. len + 1.
    assert_equal(haystack.count("".as_bytes()), 12)
    assert_equal(haystack.count("Hell".as_bytes()), 1)
    assert_equal(haystack.count("l".as_bytes()), 3)
    assert_equal(haystack.count("ll".as_bytes()), 1)
    assert_equal(haystack.count("ld".as_bytes()), 1)
    assert_equal(haystack.count("universe".as_bytes()), 0)

    # Occurrences must not overlap.
    assert_equal("aaaaa".as_bytes().count("a".as_bytes()), 5)
    assert_equal("aaaaaa".as_bytes().count("aa".as_bytes()), 3)


def main():
test_span_list_int()
test_span_list_str()
Expand All @@ -214,3 +228,4 @@ def main():
test_bool()
test_fill()
test_ref()
test_count()