Skip to content

Commit

Permalink
Add unrolled functions and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Mar 30, 2024
1 parent 78c2be0 commit 0a1471f
Show file tree
Hide file tree
Showing 4 changed files with 556 additions and 3 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "Test"]
test = ["Aqua", "JET", "OrderedCollections", "PrettyTables", "SafeTestsets", "Test"]
156 changes: 155 additions & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,159 @@
"""
UnrolledUtilities
A collection of generated functions in which all loops are unrolled and inlined.
The functions exported by this module are
- `unrolled_any(f, itr)`: similar to `any`
- `unrolled_all(f, itr)`: similar to `all`
- `unrolled_foreach(f, itrs...)`: similar to `foreach`
- `unrolled_map(f, itrs...)`: similar to `map`
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_zip(itrs...)`: similar to `zip`
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but
without duplicate calls to `f`
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the
second argument wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the
second argument wrapped in a `Val`
These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
- the iterators have nonuniform element types (with the exception of `map`, all
of the corresponding functions from `Base` encounter type-instabilities and
allocations when this is the case)
- the iterators have many elements (e.g., more than 32, which is the threshold
at which `map` becomes type-unstable for `Tuple`s)
- `f` and/or `op` recursively call the function to which they is passed, with an
arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will
be type-unstable when the recursion depth exceeds 3, but this will not be the
case with `unrolled_map`)
Moreover, these functions are very likely to be optimized out through constant
propagation when the iterators have singleton element types (and when the result
of calling `f` and/or `op` on these elements is inferrable).
"""
module UnrolledUtilities

# TODO: Add source code.
export unrolled_any,
unrolled_all,
unrolled_foreach,
unrolled_map,
unrolled_reduce,
unrolled_mapreduce,
unrolled_zip,
unrolled_in,
unrolled_unique,
unrolled_filter,
unrolled_split,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
unrolled_take,
unrolled_drop

inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types)
# We could also add support for statically-sized iterators that are not Tuples.

f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type))
@inline @generated unrolled_any(f, itr) = Expr(:||, f_exprs(itr)...)
@inline @generated unrolled_all(f, itr) = Expr(:&&, f_exprs(itr)...)

function zipped_f_exprs(itr_types)
L = length(itr_types)
L == 0 && error("unrolled functions need at least one iterator as input")
N = minimum(inferred_length, itr_types)
return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N)
end
@inline @generated unrolled_foreach(f, itrs...) =
Expr(:block, zipped_f_exprs(itrs)...)
@inline @generated unrolled_map(f, itrs...) =
Expr(:tuple, zipped_f_exprs(itrs)...)

function nested_op_expr(itr_type)
N = inferred_length(itr_type)
N == 0 && error("unrolled_reduce needs an `init` value for empty iterators")
item_exprs = (:(itr[$n]) for n in 1:N)
return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs)
end
@inline @generated unrolled_reduce_without_init(op, itr) = nested_op_expr(itr)

struct NoInit end
@inline unrolled_reduce(op, itr; init = NoInit()) =
unrolled_reduce_without_init(op, init isa NoInit ? itr : (init, itr...))

@inline unrolled_mapreduce(f, op, itrs...; init_kwarg...) =
unrolled_reduce(op, unrolled_map(f, itrs...); init_kwarg...)

@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)

@inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Using === instead of == or isequal improves type stability for singletons.

@inline unrolled_unique(itr) =
unrolled_reduce(itr; init = ()) do unique_items, item
@inline
unrolled_in(item, unique_items) ? unique_items : (unique_items..., item)
end

@inline unrolled_filter(f, itr) =
unrolled_reduce(itr; init = ()) do filtered_items, item
@inline
f(item) ? (filtered_items..., item) : filtered_items
end

@inline unrolled_split(f, itr) =
unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item
@inline
f(item) ? ((f_items..., item), not_f_items) :
(f_items, (not_f_items..., item))
end

@inline unrolled_flatten(itr) =
unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ())

@inline unrolled_flatmap(f, itrs...) =
unrolled_flatten(unrolled_map(f, itrs...))

@inline unrolled_product(itrs...) =
unrolled_reduce(itrs; init = ((),)) do product_itr, itr
@inline
unrolled_flatmap(itr) do item
@inline
unrolled_map(product_tuple -> (product_tuple..., item), product_itr)
end
end

@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
@inline unrolled_drop(itr, ::Val{N}) where {N} =
ntuple(i -> itr[N + i], Val(length(itr) - N))
# When its second argument is a Val, ntuple is unrolled via Base.@ntuple.

@static if hasfield(Method, :recursion_relation)
# Remove recursion limits for functions whose arguments are also functions.
for func in (
unrolled_any,
unrolled_all,
unrolled_foreach,
unrolled_map,
unrolled_reduce_without_init,
unrolled_reduce,
unrolled_mapreduce,
unrolled_filter,
unrolled_split,
unrolled_flatmap,
)
for method in methods(func)
method.recursion_relation = (_...) -> true
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SafeTestsets

#! format: off
@safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") end
@safetestset "Aqua" begin @time include("aqua.jl") end
#! format: on
Loading

0 comments on commit 0a1471f

Please sign in to comment.