From 9bebbaaa49e5f9be115f90293bb6fdcc66bfa8a7 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 21 Nov 2024 11:53:09 -0300 Subject: [PATCH 1/4] Add Span.count() Signed-off-by: martinvuyk --- docs/changelog.md | 7 ++ stdlib/src/utils/span.mojo | 128 ++++++++++++++++++++++++++++++- stdlib/test/utils/test_span.mojo | 15 ++++ 3 files changed, 146 insertions(+), 4 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index d6c831da5b..32b4fdcdeb 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -193,6 +193,13 @@ what we publish. ([PR #3160](https://github.com/modularml/mojo/pull/3160) by [@bgreni](https://github.com/bgreni)) +- `Span` now implements a generic `.count()` method which can be passed a + function that returns a boolean SIMD vector. The function counts how many + times it returns `True` evaluating it in a vectorized manner. This works for + any `Span[Scalar[D]]` e.g. `Span[Byte]`. It also implements the Python-like + `.count(sub)` which works for any scalar value sequence. + PR [#3792](https://github.com/modularml/mojo/pull/3792) by [@martinvuyk](https://github.com/martinvuyk). + - `StringRef` now implements `split()` which can be used to split a `StringRef` into a `List[StringRef]` by a delimiter. ([PR #2705](https://github.com/modularml/mojo/pull/2705) by [@fknfilewalker](https://github.com/fknfilewalker)) diff --git a/stdlib/src/utils/span.mojo b/stdlib/src/utils/span.mojo index 3ebb316f2d..d66879d638 100644 --- a/stdlib/src/utils/span.mojo +++ b/stdlib/src/utils/span.mojo @@ -23,6 +23,7 @@ from utils import Span from collections import InlineArray from memory import Pointer, UnsafePointer from builtin.builtin_list import _lit_mut_cast +from sys.info import simdwidthof, sizeof trait AsBytes: @@ -345,8 +346,7 @@ struct Span[ return not self == rhs fn fill[origin: MutableOrigin, //](self: Span[T, origin], value: T): - """ - Fill the memory that a span references with a given value. + """Fill the memory that a span references with a given value. Parameters: origin: The inferred mutable origin of the data within the Span. @@ -358,8 +358,7 @@ struct Span[ element[] = value fn get_immutable(self) -> Span[T, _lit_mut_cast[origin, False].result]: - """ - Return an immutable version of this span. + """Return an immutable version of this span. Returns: A span covering the same elements, but without mutability. @@ -367,3 +366,124 @@ struct Span[ return Span[T, _lit_mut_cast[origin, False].result]( ptr=self._data, length=self._len ) + + fn count[D: DType, //](self: Span[Scalar[D]], sub: Span[Scalar[D]]) -> UInt: + """Return the number of non-overlapping occurrences of subsequence. + + Parameters: + D: The DType. + + Args: + sub: The subsequence. + + Returns: + The number of non-overlapping occurrences of subsequence. + """ + + if len(sub) == 1: + var s = sub.unsafe_ptr()[0] + + @parameter + fn equal_fn[w: Int](v: SIMD[D, w]) -> SIMD[DType.bool, w]: + return v == SIMD[D, w](s) + + return self.count[func=equal_fn]() + + # FIXME(#3548): this is a hack until we have Span.find(). All count + # implementations should delegate to Span.count() anyway. + return String( + StringSlice[origin]( + ptr=self.unsafe_ptr().bitcast[Byte](), + length=len(self) * sizeof[Scalar[D]](), + ) + ).count( + String( + StringSlice[origin]( + ptr=sub.unsafe_ptr().bitcast[Byte](), + length=len(sub) * sizeof[Scalar[D]](), + ) + ) + ) + + fn count[ + D: DType, //, func: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w] + ](self: Span[Scalar[D]]) -> UInt: + """Count the amount of times the function returns `True`. + + Parameters: + D: The DType. + func: The function to evaluate. + + Returns: + The amount of times the function returns `True`. + """ + + alias widths = (256, 128, 64, 32, 16, 8) + var ptr = self.unsafe_ptr() + var num_bytes = len(self) + var amnt = UInt(0) + var processed = 0 + + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((num_bytes - processed) // w): + var vec = (ptr + processed).load[width=w]() + + @parameter + if w >= 256: + amnt += int(func(vec).cast[DType.uint16]().reduce_add()) + else: + amnt += int(func(vec).cast[DType.uint8]().reduce_add()) + processed += w + + for i in range(num_bytes - processed): + amnt += int(func(ptr[processed + i])) + + return amnt + + # FIXME(#2535): delete once function effects can be parametrized + fn count[ + D: DType, //, + func: fn[w: Int] (SIMD[D, w]) capturing -> SIMD[DType.bool, w], + ](self: Span[Scalar[D]]) -> UInt: + """Count the amount of times the function returns `True`. + + Parameters: + D: The DType. + func: The function to evaluate. + + Returns: + The amount of times the function returns `True`. + """ + + alias widths = (256, 128, 64, 32, 16, 8) + var ptr = self.unsafe_ptr() + var num_bytes = len(self) + var amnt = UInt(0) + var processed = 0 + + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((num_bytes - processed) // w): + var vec = (ptr + processed).load[width=w]() + + @parameter + if w >= 256: + amnt += int(func(vec).cast[DType.uint16]().reduce_add()) + else: + amnt += int(func(vec).cast[DType.uint8]().reduce_add()) + processed += w + + for i in range(num_bytes - processed): + print(i, func(ptr[processed + i])) + amnt += int(func(ptr[processed + i])) + + return amnt diff --git a/stdlib/test/utils/test_span.mojo b/stdlib/test/utils/test_span.mojo index 79cd780401..c5294d30f2 100644 --- a/stdlib/test/utils/test_span.mojo +++ b/stdlib/test/utils/test_span.mojo @@ -203,6 +203,20 @@ def test_ref(): assert_true(s.as_ref() == Pointer.address_of(l.unsafe_ptr()[])) +def test_count(): + var str = String("Hello world").as_bytes() + + assert_equal(12, str.count("".as_bytes())) + assert_equal(1, str.count("Hell".as_bytes())) + assert_equal(3, str.count("l".as_bytes())) + assert_equal(1, str.count("ll".as_bytes())) + assert_equal(1, str.count("ld".as_bytes())) + assert_equal(0, str.count("universe".as_bytes())) + + assert_equal("aaaaa".as_bytes().count("a".as_bytes()), 5) + assert_equal("aaaaaa".as_bytes().count("aa".as_bytes()), 3) + + def main(): test_span_list_int() test_span_list_str() @@ -214,3 +228,4 @@ def main(): test_bool() test_fill() test_ref() + test_count() From 9f2820fd301e475515a468e535aa01e20cc97f47 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 21 Nov 2024 12:08:25 -0300 Subject: [PATCH 2/4] fix detail Signed-off-by: martinvuyk --- stdlib/src/utils/span.mojo | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stdlib/src/utils/span.mojo b/stdlib/src/utils/span.mojo index d66879d638..c29969bb1a 100644 --- a/stdlib/src/utils/span.mojo +++ b/stdlib/src/utils/span.mojo @@ -390,7 +390,7 @@ struct Span[ return self.count[func=equal_fn]() # FIXME(#3548): this is a hack until we have Span.find(). All count - # implementations should delegate to Span.count() anyway. + # implementations should delegate to Span.count() eventually. return String( StringSlice[origin]( ptr=self.unsafe_ptr().bitcast[Byte](), @@ -483,7 +483,6 @@ struct Span[ processed += w for i in range(num_bytes - processed): - print(i, func(ptr[processed + i])) amnt += int(func(ptr[processed + i])) return amnt From 3674b4d2401fd2fc00137b7318865b3a268a26ed Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 21 Nov 2024 12:21:26 -0300 Subject: [PATCH 3/4] refactor _count_utf8_continuation_bytes Signed-off-by: martinvuyk --- stdlib/src/utils/string_slice.mojo | 37 ++++++++---------------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index b7fe463728..410921a693 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -36,30 +36,16 @@ alias StaticString = StringSlice[StaticConstantOrigin] """An immutable static string slice.""" -fn _count_utf8_continuation_bytes(span: Span[Byte]) -> Int: - alias sizes = (256, 128, 64, 32, 16, 8) - var ptr = span.unsafe_ptr() - var num_bytes = len(span) - var amnt: Int = 0 - var processed = 0 - - @parameter - for i in range(len(sizes)): - alias s = sizes.get[i, Int]() - - @parameter - if simdwidthof[DType.uint8]() >= s: - var rest = num_bytes - processed - for _ in range(rest // s): - var vec = (ptr + processed).load[width=s]() - var comp = (vec & 0b1100_0000) == 0b1000_0000 - amnt += int(comp.cast[DType.uint8]().reduce_add()) - processed += s +@always_inline +fn _is_continuation_byte[ + w: Int +](vec: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]: + return (vec & 0b1100_0000) == 0b1000_0000 - for i in range(num_bytes - processed): - amnt += int((ptr[processed + i] & 0b1100_0000) == 0b1000_0000) - return amnt +@always_inline +fn _count_utf8_continuation_bytes(span: Span[Byte]) -> Int: + return span.count[func=_is_continuation_byte]() fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int: @@ -76,11 +62,8 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int: this does not work correctly if given a continuation byte.""" debug_assert( - (b & 0b1100_0000) != 0b1000_0000, - ( - "Function `_utf8_first_byte_sequence_length()` does not work" - " correctly if given a continuation byte." - ), + not _is_continuation_byte(b), + "Function does not work correctly if given a continuation byte.", ) var flipped = ~b return int(count_leading_zeros(flipped) + (flipped >> 7)) From b19cd7aeacf6a159d3f515585d829d3e778a7ad6 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 21 Nov 2024 12:24:35 -0300 Subject: [PATCH 4/4] fix detail Signed-off-by: martinvuyk --- stdlib/src/utils/string_slice.mojo | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index 410921a693..4cca1679b0 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -65,8 +65,7 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int: not _is_continuation_byte(b), "Function does not work correctly if given a continuation byte.", ) - var flipped = ~b - return int(count_leading_zeros(flipped) + (flipped >> 7)) + return int(count_leading_zeros(~b)) + int(b < 0b1000_0000) fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):