diff --git a/docs/src/index.md b/docs/src/index.md index 9c86ba9d..6ff57b09 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,6 +26,9 @@ This package defines: * [`Open`](@ref), indicating the endpoint value of the interval is not included * [`Unbounded`](@ref), indicating the endpoint value is effectively infinite +You can create your own `AbstractInterval` type by following the interface specification +provided below. + ## Sets A single interval can be used to represent a contiguous set within a domain but cannot be @@ -262,6 +265,26 @@ julia> plot(intervals, 1:11) In the plot, inclusive boundaries are marked with a vertical bar, whereas exclusive boundaries just end. +## Interval Interface + +To create your own `AbstractInterval` type you need to define how to get the lower and upper +bound of the interval and how to construct new intervals of your type from these bounds. +All other functions defined in this package should work correctly if you do this. + +Construction of new intervals requires knowing how to intervals types interact: e.g. if you compute the intersection of two intervals, one that's of your new interval type and one that's an Interval, what should be returned? The default factory always constructs `Interval` objects regardless of the input types. To define a different behavior, you create an `Interval.AbstractFactory` subtype. + +```@docs +Intervals.AbstractFactory +Intervals.factory +Intervals.interval_type +``` + +To define how to get the lower and upper bound of your type, define a method for `Intervals.LowerBound` and `Intervals.UpperBound`; these are intended to be constructors and are expected to return objects of their respective type. + +```@docs +Intervals.UpperBound(::AbstractInterval, ::Intervals.AbstractFactory) +Intervals.LowerBound(::AbstractInterval, ::Intervals.AbstractFactory) +``` ## API diff --git a/src/endpoint.jl b/src/endpoint.jl index eb91b844..91a3456e 100644 --- a/src/endpoint.jl +++ b/src/endpoint.jl @@ -33,8 +33,10 @@ const RightEndpoint{T,B} = Endpoint{T, Right, B} where {T,B <: Bound} LeftEndpoint{B}(ep::T) where {T,B} = LeftEndpoint{T,B}(ep) RightEndpoint{B}(ep::T) where {T,B} = RightEndpoint{T,B}(ep) -LeftEndpoint(i::AbstractInterval{T,L,R}) where {T,L,R} = LeftEndpoint{T,L}(L !== Unbounded ? first(i) : nothing) -RightEndpoint(i::AbstractInterval{T,L,R}) where {T,L,R} = RightEndpoint{T,R}(R !== Unbounded ? last(i) : nothing) +LeftEndpoint(i::AbstractInterval) = LeftEndpoint(i, interval_factory(i)) +RightEndpoint(i::AbstractInterval) = RightEndpoint(i, interval_factory(i)) +LeftEndpoint(i::AbstractInterval{T,L,R}, ::DefaultFactory) where {T,L,R} = LeftEndpoint{T,L}(L !== Unbounded ? first(i) : nothing) +RightEndpoint(i::AbstractInterval{T,L,R}, ::DefaultFactory) where {T,L,R} = RightEndpoint{T,R}(R !== Unbounded ? last(i) : nothing) endpoint(x::Endpoint) = isbounded(x) ? x.endpoint : nothing bound_type(x::Endpoint{T,D,B}) where {T,D,B} = B diff --git a/src/interval.jl b/src/interval.jl index cb18c8f9..2daab084 100644 --- a/src/interval.jl +++ b/src/interval.jl @@ -389,30 +389,24 @@ function contiguous(a::AbstractInterval, b::AbstractInterval) ) end -function Base.intersect(a::AbstractInterval{T}, b::AbstractInterval{T}) where T - !overlaps(a,b) && return Interval{T}() - left = max(LeftEndpoint(a), LeftEndpoint(b)) - right = min(RightEndpoint(a), RightEndpoint(b)) - - return Interval{T}(left, right) -end - -function Base.intersect(a::AbstractInterval{S}, b::AbstractInterval{T}) where {S,T} - !overlaps(a, b) && return Interval{promote_type(S, T)}() - left = max(LeftEndpoint(a), LeftEndpoint(b)) - right = min(RightEndpoint(a), RightEndpoint(b)) +function Base.intersect(a::AbstractInterval, b::AbstractInterval) + tracking = endpoint_tracking(a, b) + !overlaps(a,b) && return tointerval(tracking) + left = max(LeftEndpoint(a, tracking), LeftEndpoint(b, tracking)) + right = min(RightEndpoint(a, tracking), RightEndpoint(b, tracking)) - return Interval(left, right) + return tracking(left, right) end function Base.merge(a::AbstractInterval, b::AbstractInterval) + tracking = endpoint_tracking(a, b) if !overlaps(a, b) && !contiguous(a, b) throw(ArgumentError("$a and $b are neither overlapping or contiguous.")) end - left = min(LeftEndpoint(a), LeftEndpoint(b)) - right = max(RightEndpoint(a), RightEndpoint(b)) - return Interval(left, right) + left = min(LeftEndpoint(a, tracking), LeftEndpoint(b, tracking)) + right = max(RightEndpoint(a, tracking), RightEndpoint(b, tracking)) + return tracking(left, right) end ##### ROUNDING ##### diff --git a/src/interval_sets.jl b/src/interval_sets.jl index 16211696..a53f3625 100644 --- a/src/interval_sets.jl +++ b/src/interval_sets.jl @@ -101,60 +101,163 @@ const AbstractIntervals = Union{AbstractInterval, IntervalSet} # TrackEachEndpoint tracks endpoints on a case-by-case basis # computing closed/open with boolean flags -abstract type EndpointTracking; end -struct TrackEachEndpoint <: EndpointTracking; end +abstract type AbstractEndpointTracking{T}; end +struct TrackEachEndpoint{T, F} <: AbstractEndpointTracking{T} + factory::F +end +LeftEndpoint(interval::AbstractInterval, tr::AbstractEndpointTracking) = LeftEndpoint(interval, tr.factory) +RightEndpoint(interval::AbstractInterval, tr::AbstractEndpointTracking) = RightEndpoint(interval, tr.factory) + # TrackLeftOpen and TrackRightOpen track the endpoints statically: if the # intervals to be merged are all left open (or all right open), the resulting # output will always be all left open (or all right open). -abstract type TrackStatically{T} <: EndpointTracking; end -struct TrackLeftOpen{T} <: TrackStatically{T}; end -struct TrackRightOpen{T} <: TrackStatically{T}; end +abstract type AbstractTrackStatically{T} <: AbstractEndpointTracking{T}; end +struct TrackLeftOpen{T, F} <: AbstractTrackStatically{T} + factory::F +end +struct TrackRightOpen{T, F} <: AbstractTrackStatically{T} + factory::F +end function endpoint_tracking( ::Type{<:AbstractInterval{T,Open,Closed}}, ::Type{<:AbstractInterval{U,Open,Closed}}, + factory ) where {T,U} W = promote_type(T, U) - return TrackLeftOpen{W}() + return TrackLeftOpen{W}(factory) end function endpoint_tracking( ::Type{<:AbstractInterval{T,Closed,Open}}, ::Type{<:AbstractInterval{U,Closed,Open}}, + factory ) where {T,U} W = promote_type(T, U) - return TrackRightOpen{W}() + return TrackRightOpen{W}(factory) +end +function endpoint_tracking( + ::Type{<:AbstractInterval{T}}, + ::Type{<:AbstractInterval{U}}, + factory +) + W = promote_type(T, U) + return TrackEachEndpoint{W}(factory) end function endpoint_tracking( ::Type{<:AbstractInterval}, ::Type{<:AbstractInterval}, + factory ) - return TrackEachEndpoint() + return TrackEachEndpoint{Any}(factory) end -endpoint_tracking(a::IntervalSet, b::IntervalSet) = endpoint_tracking(eltype(a), eltype(b)) -endpoint_tracking(a::AbstractInterval, b::AbstractInterval) = endpoint_tracking(typeof(a), typeof(b)) -endpoint_tracking(a::AbstractVector, b::AbstractVector) = endpoint_tracking(eltype(a), eltype(b)) +endpoint_tracking(a::IntervalSet, b::IntervalSet) = endpoint_tracking(eltype(a), eltype(b), factory(a, b)) +endpoint_tracking(a::AbstractInterval, b::AbstractInterval) = endpoint_tracking(typeof(a), typeof(b), factory(a, b)) +endpoint_tracking(a::AbstractVector, b::AbstractVector) = endpoint_tracking(eltype(a), eltype(b), factory(a, b)) + +# When we split intervals into endpoints we also need a way to construct new intervals from +# the split endpoints. This is done using an factory. The default one just calls `Interval` +# on the endpoints. These is some additional tracking of types that needs to be handled to +# indicate the proper eltype for arrays of the constructed intervals. The methods are setup +# to minimize the number of methods that need to be overloaded to define a new type of +# factory (for a differnet concrete interval type). +""" + struct AbstractFactory <: Function; end + +A callable object for constructing new intervals. Concrete types that are `isa +AbstractFactory` must define two methods: one dispatching on a `LowerBound` and `UpperBound` +object to construct an interval from these bounds, and the other dispatching on an empty +argument list, which should construct an empty interval (or throw an error if this is not +possible). They should also define a method of [`interval_type`](@ref) +""" +struct AbstractFactory <: Function; end +struct DefaultFactory{T,D} <: AbstractFactory; end +(tr::AbstractEndpointTracking)(a, b) = tr.factory(a, b) +(::DefaultFactory{T})(a::AbstractEndpoint, b::AbstractEndpoint) where T = Interval{T}(a, b) +(::DefaultFactory{T})() where T = Interval{T,Closed,Open}(zero(T), zero(T)) +(::DefaultFactory{T,:LeftOpen})() where T = Interval{T,Open,Closed}(zero(T), zero(T)) + +""" + factory(a::AbstractInterval) + +Returns an `AbstractFactory` object defining how a new interval should be constructed from +the bounds of `a`. Defaults to the internal type `Intervals.DefaultFactory` which will +construct new objects as the `Interval` type. +""" +factory(a::AbstractInterval{T}) = DefaultFactory{T, :RightOpen}() +factory(a::AbstractInterval{T,Open,Closed}) = DefaultFactory{T, :LeftOpen}() + +""" + factory(a::AbstractInterval, b::AbstractInterval) + +Returns an `AbstractFactory` object defining how a new interval should be constructed when +pulling from the bounds of `a` and `b` (e.g. the intersection of `a` and `b`). Fallback +methods default to the internal type `Intervals.DefaultFactory` which will construct new +objects with the `Interval` type. +""" +factory(x::AbstractInterval{T}, y::AbstractInterval{U}) where {T,U} = DefaultFactory{promote_type(T, U), :RightOpen}() +factory(x::AbstractInterval{T,Open,Closed}, y::AbstractInterval{U,Open,Closed}) where {T,U} = DefaultFactory{promote_type{T, U}, :LeftOpen}() +factory(x::AbstractInterval, y::AbstractInterval) = DefaultFactory{Any, :RightOpen}() + +""" + factory(a::AbstractVector{<:AbstractInterval}, b::Abstract{<:AbstractInterval}) + +Returns an `AbstractFactory` object defining how a new interval should be constructed when +pulling from the bounds of all intervals in `a` and `b` (e.g. the intersection of `a[1]` and +`b[2]`). Fallback methods default to the internal type `Intervals.DefaultFactory` which will +construct new objects with the `Interval` type. +""" +factory(x::AbstractVector{<:AbstractInterval{T}}, y::AbstractVector{<:AbstractInterval{U}}) where {T,U} = DefaultFactory{promote_type(T,U), :RightOpen}() +factory(x::AbstractVector{<:AbstractInterval{T,Open,Closed}}, y::AbstractVector{<:AbstractInterval{U,Open,Closed}}) where {T,U} = DefaultFactory{promote_type(T,U), :LeftOpen}() +factory(x::AbstractVector{<:AbstractInterval}, y::AbstractVector{<:AbstractInterval}) DefaultFactory{Any, :RightOpen}() +factory(a::IntervalSet, b::IntervalSet) = factory(a.items, b.items) + +""" + interval_type(x::AbstractFactory, L::Type{<:Bound}, U::Type{<:Bound}) + +Given a factory and the lower and upper boundings (`Closed/Open/Unbounded`) return the +expected type of the interval. If your specific factory only constructs intervals with a +fixed boundedness you can safely implement a single-argument method of this function, +since there is a fall back that drops the last two arguments. +""" +interval_type(x, L, R) = interval_type(x) +interval_type(::DefaultFactory{T}) where {T} = Interval{T} +interval_type(::DefaultFactory{T}, L, R) where {T} = Interval{T,L,R} # track: run a thunk, but only if we are tracking endpoints dynamically track(fn::Function, ::TrackEachEndpoint, args...) = fn(args...) -track(_, tracking::TrackStatically, args...) = tracking +track(_, tracking::AbstractTrackStatically, args...) = tracking -endpoint_type(::TrackEachEndpoint) = Endpoint +function endpoint_type(::AbstractInterval{T,L,R}) where {T,L,R} + return Union{LeftEndpoint{T,L}, RightEndpoint{T,R}} +end +# if eltype is not concrete give an abstract endpoint type; note that if we were to dispatch +# on AbstractVector{<:AbstractInterval} here would enforce a concrete eltype +function endpoint_type(x::AbstractVector) + eltype(x) isa AbstractInterval || error("Expected vector of intervals") + return Endpoint +end +# if eltype is concrete, give a union of concrete endpoint types +function endpoint_type(::AbstractVector{I}) where {T,L,R,I <: AbstractInterval{T,L,R}} + return Union{LeftEndpoint{T,L}, RightEndpoint{T,R}} +end +endpoint_type(::TrackEachEndpoint{T}) where T = Endpoint{T} endpoint_type(::TrackLeftOpen{T}) where T = Union{LeftEndpoint{T,Open}, RightEndpoint{T, Closed}} endpoint_type(::TrackRightOpen{T}) where T = Union{LeftEndpoint{T,Closed}, RightEndpoint{T, Open}} -interval_type(::TrackEachEndpoint) = Interval -interval_type(::TrackLeftOpen{T}) where T = Interval{T, Open, Closed} -interval_type(::TrackRightOpen{T}) where T = Interval{T, Closed, Open} +interval_type(track::TrackEachEndpoint) = interval_type(track.factory) +interval_type(track::TrackRightOpen{T}) where {T} = interval_type(track.factory, Closed, Open) +interval_type(track::TrackLeftOpen{T}) where {T} = interval_type(track.factory, Open, Closed) # `unbunch/bunch`: the generic operation used to implement all set operations operates on a # series of sorted endpoints (see `mergesets` below); this first requires that # all vectors of sets be represented by their endpoints. The functions unbunch # and bunch convert between an interval and an endpoint representation -function unbunch(interval::AbstractInterval, tracking::EndpointTracking; lt=isless) - return endpoint_type(tracking)[LeftEndpoint(interval), RightEndpoint(interval)] +function unbunch(interval::AbstractInterval, tracking::AbstractEndpointTracking; lt=isless) + return endpoint_type(interval)[LeftEndpoint(interval, tracking), + RightEndpoint(interval, tracking)] end -function unbunch(intervals::IntervalSet, tracking::EndpointTracking; kwargs...) +function unbunch(intervals::IntervalSet, tracking::AbstractEndpointTracking; kwargs...) return unbunch(convert(Vector, intervals), tracking; kwargs...) end unbunch_by_fn(_) = identity @@ -163,7 +266,7 @@ function unbunch( AbstractVector{<:AbstractInterval}, Base.Iterators.Enumerate{<:Union{AbstractIntervals, AbstractVector{<:AbstractInterval}}} }, - tracking::EndpointTracking; + tracking::AbstractEndpointTracking; lt=isless, ) by = unbunch_by_fn(intervals) @@ -176,7 +279,8 @@ end unbunch_by_fn(::Base.Iterators.Enumerate) = last function unbunch((i, interval)::Tuple, tracking; lt=isless) eltype = Tuple{Int, endpoint_type(tracking)} - return eltype[(i, LeftEndpoint(interval)), (i, RightEndpoint(interval))] + return eltype[(i, LeftEndpoint(interval, tracking)), + (i, RightEndpoint(interval, tracking))] end function unbunch(a::Union{AbstractVector{<:AbstractInterval}, AbstractIntervals}, @@ -188,17 +292,12 @@ function unbunch(a::Union{AbstractVector{<:AbstractInterval}, AbstractIntervals} end # represent a sequence of endpoints as a sequence of one or more intervals -function bunch(endpoints, tracking) +function bunch(endpoints::AbstractVector, tracking) @assert iseven(length(endpoints)) isempty(endpoints) && return IntervalSet(interval_type(tracking)[]) - res = map(Iterators.partition(endpoints, 2)) do pair - return Interval(pair..., tracking) - end + res = map(x -> tracking(x...), Iterators.partition(endpoints, 2)) return IntervalSet(res) end -Interval(a::Endpoint, b::Endpoint, ::TrackEachEndpoint) = Interval(a, b) -Interval(a::Endpoint, b::Endpoint, ::TrackLeftOpen{T}) where T = Interval{T,Open,Closed}(a.endpoint, b.endpoint) -Interval(a::Endpoint, b::Endpoint, ::TrackRightOpen{T}) where T = Interval{T,Closed,Open}(a.endpoint, b.endpoint) # the sentinel endpoint reduces the number of edgecases # we have to deal with when comparing endpoints during a merge @@ -260,7 +359,9 @@ isleft(::RightEndpoint) = false # open left ((1, 1]) then all resulting endpoints will follow the same pattern. function mergesets(op, x, y) - x_, y_, tracking = unbunch(union(x), union(y)) + x, y = union(x), union(y) + tracking = endpoint_tracking(x, y) + x_, y_, tracking = unbunch(x, y) return mergesets_helper(op, x_, y_, tracking) end length_(x::AbstractInterval) = 1 @@ -344,25 +445,25 @@ function mergesets_helper(op, x, y, endpoint_tracking) end # abuts: true if unioning the two endpoints would lead to a single interval (e.g. (0 1] ∪ (1, 2))) abuts(::SentinelEndpoint, _, _) = false -abuts(oldstop::Endpoint, newstart, ::TrackStatically) = oldstop.endpoint == newstart.endpoint -function abuts(oldstop::Endpoint, newstart, ::TrackEachEndpoint) - return oldstop.endpoint == newstart.endpoint && (isclosed(oldstop) || isclosed(newstart)) +abuts(oldstop::Endpoint, newstart, ::AbstractTrackStatically) = endpoint(oldstop) == endpoint(newstart) +function abuts(oldstop::Endpoint, newstart, ::AbstractTrackDynamically) + return endpoint(oldstop) == endpoint(newstart) && (isclosed(oldstop) || isclosed(newstart)) end # empty_interval: true if the given left and right endpoints would create an empty interval empty_interval(::SentinelEndpoint, _, _) = false # sentinal means there was no starting endpoint; there is thus no interval, and so no empty interval -empty_interval(start, stop, ::TrackStatically) = start.endpoint == stop.endpoint -empty_interval(start, stop, ::TrackEachEndpoint) = start > stop +empty_interval(start, stop, ::AbstractTrackStatically) = endpoint(start) == endpoint(stop) +empty_interval(start, stop, ::AbstractTrackDynamically) = start > stop # the below methods create a left or a right endpoint from the endpoint t: note # that t might not be the same type of endpoint (e.g. -# `left_endpoint(RightEndpoint(...))` is perfectly valid). `mergesets` may +# `LeftEndpoint(RightEndpoint(...), tracking)` is perfectly valid). `mergesets` may # change which side of an interval an endpoint is on. left_endpoint(t::Endpoint{T}, ::Type{B}) where {T, B <: Bound} = LeftEndpoint{T, B}(endpoint(t)) right_endpoint(t::Endpoint{T}, ::Type{B}) where {T, B <: Bound} = RightEndpoint{T, B}(endpoint(t)) -left_endpoint(t, ::TrackLeftOpen{T}) where T = LeftEndpoint{T,Open}(endpoint(t)) -left_endpoint(t, ::TrackRightOpen{T}) where T = LeftEndpoint{T,Closed}(endpoint(t)) -right_endpoint(t, ::TrackLeftOpen{T}) where T = RightEndpoint{T,Closed}(endpoint(t)) -right_endpoint(t, ::TrackRightOpen{T}) where T = RightEndpoint{T,Open}(endpoint(t)) +left_endpoint(t::Endpoint, ::TrackLeftOpen{T}) where T = LeftEndpoint{T,Open}(endpoint(t)) +left_endpoint(t::Endpoint, ::TrackRightOpen{T}) where T = LeftEndpoint{T,Closed}(endpoint(t)) +right_endpoint(t::Endpoint, ::TrackLeftOpen{T}) where T = RightEndpoint{T,Closed}(endpoint(t)) +right_endpoint(t::Endpoint, ::TrackRightOpen{T}) where T = RightEndpoint{T,Open}(endpoint(t)) ##### Multi-interval Set Operations ##### @@ -457,7 +558,7 @@ Base.in(x, y::IntervalSet) = any(Base.Fix1(in, x), y.items) # order edges so that closed boundaries are on the outside: e.g. [( )] intersection_order(x::Endpoint) = isleft(x) ? !isclosed(x) : isclosed(x) -intersection_isless_fn(::TrackStatically) = isless +intersection_isless_fn(::AbstractTrackStatically) = isless function intersection_isless_fn(::TrackEachEndpoint) function (x,y) if isequal(x, y) diff --git a/test/sets.jl b/test/sets.jl index e28e24fc..b663b2bc 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -40,6 +40,24 @@ end # interval-bound point set) @test intersect([1..2, 2..3, 3..4, 4..5], [2..3, 3..4]) == [2..3, 3..4] + # verify that the internal representation of type stable interval sets + # makes use of union splitting + @test ==(eltype(Intervals.unbunch([Interval{Open,Open}(1, 2), + Interval{Open,Open}(3,4)])), + Union{Intervals.LeftEndpoint{Int, Open}, Intervals.RightEndpoint{Int, Open}}) + @test ==(eltype(Intervals.unbunch([Interval{Closed,Closed}(1, 2), + Interval{Closed,Closed}(3,4)])), + Union{Intervals.LeftEndpoint{Int, Closed}, Intervals.RightEndpoint{Int, Closed}}) + @test ==(eltype(Intervals.unbunch([Interval{Open,Closed}(1, 2), + Interval{Open,Closed}(3,4)])), + Union{Intervals.LeftEndpoint{Int, Open}, Intervals.RightEndpoint{Int, Closed}}) + @test ==(eltype(Intervals.unbunch([Interval{Closed,Open}(1, 2), + Interval{Closed,Open}(3,4)])), + Union{Intervals.LeftEndpoint{Int, Closed}, Intervals.RightEndpoint{Int, Open}}) + @test ==(eltype(Intervals.unbunch([Interval{Closed,Open}(1, 2), + Interval{Open,Closed}(3,4)])), + Intervals.Endpoint{Int}) + # verify that elements are in / subsets of interval sets @test 2 ∈ IntervalSet([1..3, 5..10]) @test 0 ∉ IntervalSet([1..3, 5..10])