diff --git a/src/interval_sets.jl b/src/interval_sets.jl index 4d915c39..6a150e1a 100644 --- a/src/interval_sets.jl +++ b/src/interval_sets.jl @@ -161,13 +161,19 @@ unbunch_by_fn(_) = identity function unbunch( intervals::Union{ AbstractVector{<:AbstractInterval}, - Base.Iterators.Enumerate{<:Union{AbstractIntervals, AbstractVector{<:AbstractInterval}}} + Base.Iterators.Enumerate{<:Union{AbstractIntervals, + AbstractVector{<:Union{Missing, AbstractInterval}}}} }, tracking::EndpointTracking; lt=isless, ) by = unbunch_by_fn(intervals) - filtered = Iterators.filter(!isempty ∘ by, intervals) + filtered = Iterators.filter(intervals) do x + let interval = by(x) + ismissing(interval) && return false + return isempty(interval) + end + end isempty(filtered) && return endpoint_type(tracking)[] result = mapreduce(x -> unbunch(x, tracking), vcat, filtered) return sort!(result; lt, by) @@ -176,11 +182,12 @@ end unbunch_by_fn(::Base.Iterators.Enumerate) = last function unbunch((i, interval)::Tuple, tracking; lt=isless) eltype = Tuple{Int, endpoint_type(tracking)} + ismissing(interval) && return missing return eltype[(i, LeftEndpoint(interval)), (i, RightEndpoint(interval))] end -function unbunch(a::Union{AbstractVector{<:AbstractInterval}, AbstractIntervals}, - b::Union{AbstractVector{<:AbstractInterval}, AbstractIntervals}; kwargs...) +function unbunch(a::Union{AbstractVector{<:Union{Missing, AbstractInterval}}, AbstractIntervals}, + b::Union{AbstractVector{<:Union{Missing, AbstractInterval}}, AbstractIntervals}; kwargs...) tracking = endpoint_tracking(a, b) a_ = unbunch(a, tracking; kwargs...) b_ = unbunch(b, tracking; kwargs...) @@ -469,16 +476,27 @@ function intersection_isless_fn(::TrackEachEndpoint) end """ - find_intersections( - x::AbstractVector{<:AbstractInterval}, - y::AbstractVector{<:AbstractInterval} - ) + find_intersections(x, y; missings=:error) Returns a `Vector{Vector{Int}}` where the value at index `i` gives the indices to all -intervals in `y` that intersect with `x[i]`. +intervals in `y` that intersect with `x[i]`. Calls collect on the arguments if they +aren't already arrays. + +When `missings` == `:error` (the default), `find_intersections` will throw an error +if any values are missing in `x` or `y`. If set to `:skip`, missing values will +be skipped, thus behaving the same way as an interval that does not overlap with any +other interval. """ -find_intersections(x, y) = find_intersections(vcat(x), vcat(y)) -function find_intersections(x::AbstractVector{<:AbstractInterval}, y::AbstractVector{<:AbstractInterval}) +find_intersections(x, y; kwargs...) = find_intersections_(collect(x), collect(y); kwargs...) +function find_intersections_( + x::AbstractVector{<:Union{Missing, AbstractInterval}}, + y::AbstractVector{<:Union{Missing, AbstractInterval}}; + missings=:error +) + (isempty(x) || isempty(y)) && return Vector{Int}[] + if missings ∈ (:error, :skip) + throw(ArgumentError("`missings` should be set to `:error` or `:skip`")) + end tracking = endpoint_tracking(x, y) lt = intersection_isless_fn(tracking) x_endpoints = unbunch(enumerate(x), tracking; lt) diff --git a/test/sets.jl b/test/sets.jl index e28e24fc..bcc85ea9 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -51,6 +51,12 @@ end @test !issubset(11, IntervalSet([1..3, 5..10])) @test issubset(2, IntervalSet([1.0 .. 3.0, 5.0 .. 10.0])) + @test isempty(find_intersections([], [])) + @test_throws ArgumentError find_intersections([1], [2]) + @test isempty(find_intersections(Interval[], Interval[])) + @test isempty(find_intersections([1..3], Interval[])) + @test isempty(find_intersections(Interval[], [1..3])) + function testsets(a, b) @test area(a ∪ b) ≤ area(myunion(a)) + area(myunion(b)) @test area(setdiff(a, b)) ≤ area(myunion(a))