diff --git a/Project.toml b/Project.toml
index 6ea7814..4b44887 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,10 +8,11 @@ julia = "1.10"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "Test"]
+test = ["Aqua", "JET", "OrderedCollections", "PrettyTables", "SafeTestsets", "Test"]
diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl
index 9191479..1563d17 100644
--- a/src/UnrolledUtilities.jl
+++ b/src/UnrolledUtilities.jl
@@ -1,5 +1,159 @@
+"""
+    UnrolledUtilities
+
+A collection of generated functions in which all loops are unrolled and inlined.
+
+The functions exported by this module are
+- `unrolled_any(f, itr)`: similar to `any`
+- `unrolled_all(f, itr)`: similar to `all`
+- `unrolled_foreach(f, itrs...)`: similar to `foreach`
+- `unrolled_map(f, itrs...)`: similar to `map`
+- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
+- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
+- `unrolled_zip(itrs...)`: similar to `zip`
+- `unrolled_in(item, itr)`: similar to `in`
+- `unrolled_unique(itr)`: similar to `unique`
+- `unrolled_filter(f, itr)`: similar to `filter`
+- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but
+  without duplicate calls to `f`
+- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
+- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
+- `unrolled_product(itrs...)`: similar to `Iterators.product`
+- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the
+  second argument wrapped in a `Val`
+- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the
+  second argument wrapped in a `Val`
+
+These functions are guaranteed to be type-stable whenever they are given
+iterators with inferrable lengths and element types, including when
+- the iterators have nonuniform element types (with the exception of `map`, all
+  of the corresponding functions from `Base` encounter type-instabilities and
+  allocations when this is the case)
+- the iterators have many elements (e.g., more than 32, which is the threshold
+  at which `map` becomes type-unstable for `Tuple`s)
+- `f` and/or `op` recursively call the function to which they are passed, with an
+  arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will
+  be type-unstable when the recursion depth exceeds 3, but this will not be the
+  case with `unrolled_map`)
+
+Moreover, these functions are very likely to be optimized out through constant
+propagation when the iterators have singleton element types (and when the result
+of calling `f` and/or `op` on these elements is inferrable).
+"""
 module UnrolledUtilities
 
-# TODO: Add source code.
+export unrolled_any,
+    unrolled_all,
+    unrolled_foreach,
+    unrolled_map,
+    unrolled_reduce,
+    unrolled_mapreduce,
+    unrolled_zip,
+    unrolled_in,
+    unrolled_unique,
+    unrolled_filter,
+    unrolled_split,
+    unrolled_flatten,
+    unrolled_flatmap,
+    unrolled_product,
+    unrolled_take,
+    unrolled_drop
+
+inferred_length(itr_type::Type{<:Tuple}) = fieldcount(itr_type)
+# We could also add support for statically-sized iterators that are not Tuples.
+
+f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type))
+@inline @generated unrolled_any(f, itr) = Expr(:||, f_exprs(itr)...)
+@inline @generated unrolled_all(f, itr) = Expr(:&&, f_exprs(itr)...)
+
+function zipped_f_exprs(itr_types)
+    L = length(itr_types)
+    L == 0 && error("unrolled functions need at least one iterator as input")
+    N = minimum(inferred_length, itr_types)
+    return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N)
+end
+@inline @generated unrolled_foreach(f, itrs...) =
+    Expr(:block, zipped_f_exprs(itrs)...)
+@inline @generated unrolled_map(f, itrs...) =
+    Expr(:tuple, zipped_f_exprs(itrs)...)
+
+function nested_op_expr(itr_type)
+    N = inferred_length(itr_type)
+    N == 0 && error("unrolled_reduce needs an `init` value for empty iterators")
+    item_exprs = (:(itr[$n]) for n in 1:N)
+    return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs)
+end
+@inline @generated unrolled_reduce_without_init(op, itr) = nested_op_expr(itr)
+
+struct NoInit end
+@inline unrolled_reduce(op, itr; init = NoInit()) =
+    unrolled_reduce_without_init(op, init isa NoInit ? itr : (init, itr...))
+
+@inline unrolled_mapreduce(f, op, itrs...; init_kwarg...) =
+    unrolled_reduce(op, unrolled_map(f, itrs...); init_kwarg...)
+
+@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)
+
+@inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
+# Using === instead of == or isequal improves type stability for singletons.
+
+@inline unrolled_unique(itr) =
+    unrolled_reduce(itr; init = ()) do unique_items, item
+        @inline
+        unrolled_in(item, unique_items) ? unique_items : (unique_items..., item)
+    end
+
+@inline unrolled_filter(f, itr) =
+    unrolled_reduce(itr; init = ()) do filtered_items, item
+        @inline
+        f(item) ? (filtered_items..., item) : filtered_items
+    end
+
+@inline unrolled_split(f, itr) =
+    unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item
+        @inline
+        f(item) ? ((f_items..., item), not_f_items) :
+        (f_items, (not_f_items..., item))
+    end
+
+@inline unrolled_flatten(itr) =
+    unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ())
+
+@inline unrolled_flatmap(f, itrs...) =
+    unrolled_flatten(unrolled_map(f, itrs...))
+
+@inline unrolled_product(itrs...) =
+    unrolled_reduce(itrs; init = ((),)) do product_itr, itr
+        @inline
+        unrolled_flatmap(itr) do item
+            @inline
+            unrolled_map(product_tuple -> (product_tuple..., item), product_itr)
+        end
+    end
+
+@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
+@inline unrolled_drop(itr, ::Val{N}) where {N} =
+    ntuple(i -> itr[N + i], Val(length(itr) - N))
+# When its second argument is a Val, ntuple is unrolled via Base.@ntuple.
+
+@static if hasfield(Method, :recursion_relation)
+    # Remove recursion limits for functions whose arguments are also functions.
+    for func in (
+        unrolled_any,
+        unrolled_all,
+        unrolled_foreach,
+        unrolled_map,
+        unrolled_reduce_without_init,
+        unrolled_reduce,
+        unrolled_mapreduce,
+        unrolled_filter,
+        unrolled_split,
+        unrolled_flatmap,
+    )
+        for method in methods(func)
+            method.recursion_relation = (_...) -> true
+        end
+    end
+end
 
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2d1aac4..be00ae4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using SafeTestsets
 
 #! format: off
+@safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") end
 @safetestset "Aqua" begin @time include("aqua.jl") end
 #! format: on
diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl
new file mode 100644
index 0000000..74ba3d8
--- /dev/null
+++ b/test/test_and_analyze.jl
@@ -0,0 +1,402 @@
+using Test
+using JET
+using OrderedCollections
+using PrettyTables
+
+using UnrolledUtilities
+
+measurements_dict = OrderedDict()
+
+# Return the `cache` field of the single specialization of `f` whose `specTypes`
+# match the types of `args`; `f` must have exactly one applicable method.
+function code_instance(f, args...)
+    available_methods = methods(f, Tuple{map(typeof, args)...})
+    @assert length(available_methods) == 1
+    (; specializations) = available_methods[1]
+    specTypes = Tuple{typeof(f), map(typeof, args)...}
+    return if specializations isa Core.MethodInstance
+        @assert specializations.specTypes == specTypes
+        specializations.cache
+    else
+        matching_specialization_indices =
+            findall(specializations) do specialization
+                !isnothing(specialization) &&
+                    specialization.specTypes == specTypes
+            end
+        @assert length(matching_specialization_indices) == 1
+        specializations[matching_specialization_indices[1]].cache
+    end
+end
+
+# Check `unrolled_expr` against `reference_expr` over the variables in `args_expr`
+# (correctness, allocations, type-stability, constant propagation), then time
+# both in fresh Julia processes and record the results in `measurements_dict`.
+macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
+    @assert Meta.isexpr(args_expr, :tuple)
+    arg_names = args_expr.args
+    @assert all(arg_name -> arg_name isa Symbol, arg_names)
+    args = map(esc, arg_names)
+    unrolled_expr_str =
+        replace(string(unrolled_expr), r"\s*#=.+=#" => "", r"\s+" => ' ')
+    reference_expr_str =
+        replace(string(reference_expr), r"\s*#=.+=#" => "", r"\s+" => ' ')
+    expr_info_str =
+        length(args) == 1 ? "$unrolled_expr_str with 1 iterator that contains" :
+        "$unrolled_expr_str with $(length(args)) iterators that each contain"
+    quote
+        @info "Testing $($expr_info_str) $($(esc(contents_info_str)))"
+
+        unrolled_func($(arg_names...)) = $unrolled_expr
+        reference_func($(arg_names...)) = $reference_expr
+
+        # Test for correctness.
+        @test unrolled_func($(args...)) == reference_func($(args...))
+
+        unrolled_func_and_nothing($(arg_names...)) = ($unrolled_expr; nothing)
+        reference_func_and_nothing($(arg_names...)) = ($reference_expr; nothing)
+
+        unrolled_func_and_nothing($(args...)) # Run once to compile.
+        reference_func_and_nothing($(args...))
+
+        # Test for allocations.
+        @test (@allocated unrolled_func_and_nothing($(args...))) == 0
+        is_reference_non_allocating =
+            (@allocated reference_func_and_nothing($(args...))) == 0
+
+        # Test for type-stability.
+        @test_opt unrolled_func($(args...))
+        is_reference_stable =
+            isempty(JET.get_reports(@report_opt reference_func($(args...))))
+
+        unrolled_instance = code_instance(unrolled_func, $(args...))
+        reference_instance = code_instance(reference_func, $(args...))
+
+        # Test for constant propagation.
+        is_unrolled_const = isdefined(unrolled_instance, :rettype_const)
+        Base.issingletontype(typeof(($(args...),))) && @test is_unrolled_const
+        is_reference_const = isdefined(reference_instance, :rettype_const)
+
+        arg_name_strs = ($(map(string, arg_names)...),)
+        arg_names_str = join(arg_name_strs, ", ")
+        arg_definition_strs =
+            map((name, value) -> "$name = $value", arg_name_strs, ($(args...),))
+        arg_definitions_str = join(arg_definition_strs, '\n')
+        unrolled_command_str = """
+            using UnrolledUtilities
+            unrolled_func($arg_names_str) = $($unrolled_expr_str)
+            $arg_definitions_str
+            stats = @timed unrolled_func($arg_names_str)
+            print(stats.time, ',', stats.bytes)
+            """
+        reference_command_str = """
+            reference_func($arg_names_str) = $($reference_expr_str)
+            $arg_definitions_str
+            stats = @timed reference_func($arg_names_str)
+            print(stats.time, ',', stats.bytes)
+            """
+
+        # Get the compilation times and allocations (using this Julia executable).
+        buffer1 = IOBuffer()
+        run(pipeline(`$(Base.julia_cmd()) --project -e $unrolled_command_str`, buffer1))
+        unrolled_time, unrolled_allocs =
+            parse.((Float64, Int), split(String(take!(buffer1)), ','))
+        close(buffer1)
+        buffer2 = IOBuffer()
+        run(pipeline(`$(Base.julia_cmd()) --project -e $reference_command_str`, buffer2))
+        reference_time, reference_allocs =
+            parse.((Float64, Int), split(String(take!(buffer2)), ','))
+        close(buffer2)
+
+        # Record all of the measurements.
+        unrolled_performance_str =
+            is_unrolled_const ? "constant" : "type-stable"
+        reference_performance_str = if !is_reference_non_allocating
+            "allocating"
+        elseif !is_reference_stable
+            "type-unstable"
+        else
+            is_reference_const ? "constant" : "type-stable"
+        end
+        time_ratio = unrolled_time / reference_time
+        time_ratio_str = if time_ratio >= 1.5
+            "$(round(Int, time_ratio)) times slower"
+        elseif inv(time_ratio) >= 1.5
+            "$(round(Int, inv(time_ratio))) times faster"
+        else
+            "similar"
+        end
+        allocs_ratio = unrolled_allocs / reference_allocs
+        allocs_ratio_str = if allocs_ratio >= 1.5
+            "$(round(Int, allocs_ratio)) times more"
+        elseif inv(allocs_ratio) >= 1.5
+            "$(round(Int, inv(allocs_ratio))) times less"
+        else
+            "similar"
+        end
+        measurement_key = ($unrolled_expr_str, $reference_expr_str)
+        measurement_entry = (
+            $(esc(contents_info_str)),
+            unrolled_performance_str,
+            reference_performance_str,
+            time_ratio_str,
+            allocs_ratio_str,
+        )
+        if haskey(measurements_dict, measurement_key)
+            push!(measurements_dict[measurement_key], measurement_entry)
+        else
+            measurements_dict[measurement_key] = [measurement_entry]
+        end
+    end
+end
+
+@testset "empty iterators" begin
+    itr = ()
+    str = "nothing"
+    @test_unrolled (itr,) unrolled_any(error, itr) any(error, itr) str
+    @test_unrolled (itr,) unrolled_all(error, itr) all(error, itr) str
+    @test_unrolled (itr,) unrolled_foreach(error, itr) foreach(error, itr) str
+    @test_unrolled (itr,) unrolled_map(error, itr, itr) map(error, itr, itr) str
+    @test_unrolled(
+        (itr,),
+        unrolled_reduce(error, itr; init = 0),
+        reduce(error, itr; init = 0),
+        str,
+    )
+end
+
+for n in (1, 10, 33), all_identical in (n == 1 ? (true,) : (true, false))
+    itr1 = ntuple(i -> ntuple(Val, all_identical ? 0 : (i - 1) % 7), n)
+    itr2 = ntuple(i -> ntuple(Val, all_identical ? 1 : (i - 1) % 7 + 1), n)
+    itr3 = ntuple(i -> ntuple(identity, all_identical ? 1 : (i - 1) % 7 + 1), n)
+    if n == 1
+        str1 = "1 empty tuple"
+        str2 = "1 nonempty singleton tuple"
+        str3 = "1 nonempty non-singleton tuple"
+        str12 = "1 singleton tuple"
+        str23 = "1 nonempty tuple"
+        str123 = "1 tuple"
+    elseif all_identical
+        str1 = "$n empty tuples"
+        str2 = "$n identical nonempty singleton tuples"
+        str3 = "$n identical nonempty non-singleton tuples"
+        str12 = "$n identical singleton tuples"
+        str23 = "$n identical nonempty tuples"
+        str123 = "$n identical tuples"
+    else
+        str1 = "$n empty and nonempty singleton tuples"
+        str2 = "$n nonempty singleton tuples"
+        str3 = "$n nonempty non-singleton tuples"
+        str12 = "$n singleton tuples"
+        str23 = "$n nonempty tuples"
+        str123 = "$n tuples"
+    end
+    @testset "iterators of $str123" begin
+        for (itr, str) in ((itr1, str1), (itr2, str2), (itr3, str3))
+            @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str
+            @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str
+
+            @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str
+            @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str
+
+            @test_unrolled(
+                (itr,),
+                unrolled_foreach(x -> (@assert length(x) <= 7), itr),
+                foreach(x -> (@assert length(x) <= 7), itr),
+                str,
+            )
+
+            @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str
+
+            @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str
+            @test_unrolled(
+                (itr,),
+                unrolled_reduce(tuple, itr; init = ()),
+                reduce(tuple, itr; init = ()),
+                str,
+            )
+
+            @test_unrolled(
+                (itr,),
+                unrolled_mapreduce(length, +, itr),
+                mapreduce(length, +, itr),
+                str,
+            )
+            @test_unrolled(
+                (itr,),
+                unrolled_mapreduce(length, +, itr; init = 0),
+                mapreduce(length, +, itr; init = 0),
+                str,
+            )
+
+            @test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str
+
+            @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
+            @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
+            @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str
+
+            if Base.issingletontype(typeof(itr))
+                @test_unrolled (itr,) unrolled_unique(itr) Tuple(unique(itr)) str
+            end
+
+            @test_unrolled(
+                (itr,),
+                unrolled_filter(!isempty, itr),
+                filter(!isempty, itr),
+                str,
+            )
+
+            @test_unrolled(
+                (itr,),
+                unrolled_split(isempty, itr),
+                (filter(isempty, itr), filter(!isempty, itr)),
+                str,
+            )
+
+            @test_unrolled(
+                (itr,),
+                unrolled_flatten(itr),
+                Tuple(Iterators.flatten(itr)),
+                str,
+            )
+
+            @test_unrolled(
+                (itr,),
+                unrolled_flatmap(reverse, itr),
+                Tuple(Iterators.flatmap(reverse, itr)),
+                str,
+            )
+
+            @test_unrolled(
+                (itr,),
+                unrolled_product(itr),
+                Tuple(Iterators.product(itr)),
+                str,
+            )
+
+            if n > 1
+                @test_unrolled(
+                    (itr,),
+                    unrolled_take(itr, Val(7)),
+                    itr[1:7],
+                    str,
+                )
+                @test_unrolled(
+                    (itr,),
+                    unrolled_drop(itr, Val(7)),
+                    itr[8:end],
+                    str,
+                )
+            end
+        end
+
+        @test_unrolled(
+            (itr3,),
+            unrolled_any(x -> unrolled_reduce(+, x) > 7, itr3),
+            any(x -> reduce(+, x) > 7, itr3),
+            str3,
+        )
+
+        @test_unrolled(
+            (itr3,),
+            unrolled_mapreduce(x -> unrolled_reduce(+, x), max, itr3),
+            mapreduce(x -> reduce(+, x), max, itr3),
+            str3,
+        )
+
+        @test_unrolled(
+            (itr1, itr2),
+            unrolled_foreach(
+                (x1, x2) -> (@assert length(x1) < length(x2)),
+                itr1,
+                itr2,
+            ),
+            foreach((x1, x2) -> (@assert length(x1) < length(x2)), itr1, itr2),
+            str12,
+        )
+        @test_unrolled(
+            (itr2, itr3),
+            unrolled_foreach(
+                (x2, x3) -> (@assert x2 == unrolled_map(Val, x3)),
+                itr2,
+                itr3,
+            ),
+            foreach((x2, x3) -> (@assert x2 == map(Val, x3)), itr2, itr3),
+            str23,
+        )
+
+        @test_unrolled(
+            (itr1, itr2),
+            unrolled_zip(itr1, itr2),
+            Tuple(zip(itr1, itr2)),
+            str12,
+        )
+        @test_unrolled(
+            (itr1, itr2, itr3),
+            unrolled_zip(itr1, itr2, itr3),
+            Tuple(zip(itr1, itr2, itr3)),
+            str123,
+        )
+
+        @test_unrolled(
+            (itr1, itr2),
+            unrolled_product(itr1, itr2),
+            Tuple(Iterators.product(itr1, itr2)),
+            str12,
+        )
+        if n <= 10 # This can take several minutes to compile when n is large.
+            @test_unrolled(
+                (itr1, itr2, itr3),
+                unrolled_product(itr1, itr2, itr3),
+                Tuple(Iterators.product(itr1, itr2, itr3)),
+                str123,
+            )
+        end
+    end
+end
+
+table_data = mapreduce(vcat, collect(measurements_dict)) do (key, entries)
+    stack(entry -> (key..., entry...), entries; dims = 1)
+end
+header_line1 = [
+    "Unrolled Expression",
+    "Reference Expression",
+    "Iterator Contents",
+    "Unrolled Performance",
+    "Reference Performance",
+    "Compilation Time",
+    "Compilation Memory",
+]
+header_line2 =
+    ["", "", "", "", "", "(Unrolled vs. Reference)", "(Unrolled vs. Reference)"]
+better_performance_but_harder_to_compile =
+    Highlighter(crayon"blue") do data, i, j
+        data[i, 4] != data[i, 5] &&
+            (endswith(data[i, 6], "slower") || endswith(data[i, 7], "more"))
+    end
+better_performance =
+    Highlighter((data, i, j) -> data[i, 4] != data[i, 5], crayon"green")
+harder_to_compile = Highlighter(crayon"red") do data, i, j
+    endswith(data[i, 6], "slower") || endswith(data[i, 7], "more")
+end
+easier_to_compile = Highlighter(crayon"magenta") do data, i, j
+    endswith(data[i, 6], "faster") || endswith(data[i, 7], "less")
+end
+no_difference = Highlighter((data, i, j) -> true, crayon"yellow")
+pretty_table(
+    table_data;
+    title = "Comparison between UnrolledUtilities and Base/Base.Iterators",
+    header = (header_line1, header_line2),
+    subheader_crayon = crayon"bold",
+    highlighters = (
+        better_performance_but_harder_to_compile,
+        better_performance,
+        harder_to_compile,
+        easier_to_compile,
+        no_difference,
+    ),
+    title_same_width_as_table = true,
+    title_alignment = :c,
+    alignment = :l,
+    columns_width = [45, 45, 0, 0, 0, 0, 0],
+    crop = :none,
+)