diff --git a/src/enumerable/enumerable_drop.jl b/src/enumerable/enumerable_drop.jl index 4ee0e7c..cef74fb 100644 --- a/src/enumerable/enumerable_drop.jl +++ b/src/enumerable/enumerable_drop.jl @@ -1,23 +1,22 @@ -struct EnumerableDrop{T,S} <: Enumerable +struct EnumerableDrop{S} <: Enumerable source::S n::Int end function drop(source::Enumerable, n::Integer) - T = eltype(source) S = typeof(source) - return EnumerableDrop{T,S}(source, Int(n)) + return EnumerableDrop{S}(source, Int(n)) end -Base.iteratorsize(::Type{EnumerableDrop{T,S}}) where {T,S} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.SizeUnknown() +Base.iteratorsize(::Type{EnumerableDrop{S}}) where {S} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.SizeUnknown() -Base.eltype(iter::EnumerableDrop{T,S}) where {T,S} = T +Base.iteratoreltype(::Type{EnumerableDrop{S}}) where {S} = Base.iteratoreltype(S) -Base.eltype(::Type{EnumerableDrop{T,S}}) where {T,S} = T +Base.eltype(::Type{EnumerableDrop{S}}) where {S} = eltype(S) -Base.length(iter::EnumerableDrop{T,S}) where {T,S} = max(length(iter.source)-iter.n,0) +Base.length(iter::EnumerableDrop{S}) where {S} = max(length(iter.source)-iter.n,0) -function Base.start(iter::EnumerableDrop{T,S}) where {T,S} +function Base.start(iter::EnumerableDrop{S}) where {S} source_state = start(iter.source) for i in 1:iter.n if done(iter.source, source_state) @@ -29,10 +28,10 @@ function Base.start(iter::EnumerableDrop{T,S}) where {T,S} return source_state end -function Base.next(iter::EnumerableDrop{T,S}, s) where {T,S} +function Base.next(iter::EnumerableDrop{S}, s) where {S} return next(iter.source, s) end -function Base.done(iter::EnumerableDrop{T,S}, state) where {T,S} +function Base.done(iter::EnumerableDrop{S}, state) where {S} return done(iter.source, state) end diff --git a/src/enumerable/enumerable_filter.jl b/src/enumerable/enumerable_filter.jl index 7646dda..7f770da 100644 --- a/src/enumerable/enumerable_filter.jl +++ b/src/enumerable/enumerable_filter.jl @@ -1,41 +1,34 @@ -# T is the type of the elements produced by this iterator -struct EnumerableFilter{T,S,Q<:Function} <: Enumerable +# This is the HasEltype() version +struct EnumerableFilterHasEltype{T,S,Q<:Function} <: Enumerable source::S filter::Q end -Base.eltype(iter::EnumerableFilter{T,S,Q}) where {T,S,Q} = T +Base.eltype(iter::EnumerableFilterHasEltype{T,S,Q}) where {T,S,Q} = T -Base.eltype(iter::Type{EnumerableFilter{T,S,Q}}) where {T,S,Q} = T +Base.eltype(iter::Type{EnumerableFilterHasEltype{T,S,Q}}) where {T,S,Q} = T -struct EnumerableFilterState{T,S} +struct EnumerableFilterHasEltypeState{T,S} done::Bool next_value::Nullable{T} source_state::S end -function filter(source::Enumerable, filter_func::Function, filter_expr::Expr) - T = eltype(source) - S = typeof(source) - Q = typeof(filter_func) - return EnumerableFilter{T,S,Q}(source, filter_func) -end - -function Base.start(iter::EnumerableFilter{T,S,Q}) where {T,S,Q} +function Base.start(iter::EnumerableFilterHasEltype{T,S,Q}) where {T,S,Q} s = start(iter.source) while !done(iter.source, s) v,t = next(iter.source, s) if iter.filter(v) - return EnumerableFilterState(false, Nullable(v), t) + return EnumerableFilterHasEltypeState(false, Nullable(v), t) end s = t end # The s we return here is fake, just to make sure we # return something of the right type - return EnumerableFilterState(true, Nullable{T}(), s) + return EnumerableFilterHasEltypeState(true, Nullable{T}(), s) end -function Base.next(iter::EnumerableFilter{T,S,Q}, state) where {T,S,Q} +function Base.next(iter::EnumerableFilterHasEltype{T,S,Q}, state) where {T,S,Q} v = get(state.next_value) s = state.source_state while !done(iter.source,s) @@ -44,14 +37,71 @@ function Base.next(iter::EnumerableFilter{T,S,Q}, state) where {T,S,Q} t = temp[2] if iter.filter(w)::Bool temp2 = Nullable(w) - new_state = EnumerableFilterState(false, temp2, t) + new_state = EnumerableFilterHasEltypeState(false, temp2, t) return v, new_state end s=t end # The s we return here is fake, just to make sure we # return something of the right type - v, EnumerableFilterState(true,Nullable{T}(), s) + v, EnumerableFilterHasEltypeState(true,Nullable{T}(), s) +end + +Base.done(f::EnumerableFilterHasEltype{T,S,Q}, state) where {T,S,Q} = state.done + +# This is the EltypeUnknown() version + +struct EnumerableFilterEltypeUnknown{S,Q<:Function} <: Enumerable + source::S + filter::Q +end + +Base.iteratoreltype(::Type{EnumerableFilterEltypeUnknown{S,Q}}) where {S,Q} = Base.EltypeUnknown() + +function Base.start(iter::EnumerableFilterEltypeUnknown{S,Q}) where {S,Q} + s = start(iter.source) + while !done(iter.source, s) + v,t = next(iter.source, s) + if iter.filter(v) + return (false, v, t) + end + s = t + end + return (true, ) end -Base.done(f::EnumerableFilter{T,S,Q}, state) where {T,S,Q} = state.done +function Base.next(iter::EnumerableFilterEltypeUnknown{S,Q}, state) where {S,Q} + v = state[2] + s = state[3] + while !done(iter.source,s) + temp = next(iter.source,s) + w = temp[1] + t = temp[2] + if iter.filter(w)::Bool + return v, (false, w, t) + end + s=t + end + v, (true, v, s) +end + +Base.done(f::EnumerableFilterEltypeUnknown{S,Q}, state) where {S,Q} = state[1] + +# Implementation of the query operator + +function _filter(source::Enumerable, f::Function, f_expr::Expr, ::Base.EltypeUnknown) + S = typeof(source) + Q = typeof(f) + return EnumerableFilterEltypeUnknown{S,Q}(source, f) +end + +function _filter(source::Enumerable, f::Function, f_expr::Expr, ::Base.HasEltype) + T = eltype(source) + S = typeof(source) + Q = typeof(f) + return EnumerableFilterHasEltype{T,S,Q}(source, f) +end + +function filter(source::T, filter_func::Function, filter_expr::Expr) where {T<:Enumerable} + return _filter(source, filter_func, filter_expr, Base.iteratoreltype(T)) +end diff --git a/src/enumerable/enumerable_map.jl b/src/enumerable/enumerable_map.jl index 3c6b3a2..cc1ad7b 100644 --- a/src/enumerable/enumerable_map.jl +++ b/src/enumerable/enumerable_map.jl @@ -1,37 +1,84 @@ -struct EnumerableMap{T, S, Q<:Function} <: Enumerable +# This is the HasEltype() version + +struct EnumerableMapHasEltype{T, S, Q<:Function} <: Enumerable source::S f::Q end -Base.iteratorsize(::Type{EnumerableMap{T,S,Q}}) where {T,S,Q} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.iteratorsize(S) +Base.iteratorsize(::Type{EnumerableMapHasEltype{T,S,Q}}) where {T,S,Q} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.iteratorsize(S) -Base.eltype(iter::EnumerableMap{T,S,Q}) where {T,S,Q} = T +Base.eltype(iter::Type{EnumerableMapHasEltype{T,S,Q}}) where {T,S,Q} = T -Base.eltype(iter::Type{EnumerableMap{T,S,Q}}) where {T,S,Q} = T +Base.length(iter::EnumerableMapHasEltype{T,S,Q}) where {T,S,Q} = length(iter.source) -Base.length(iter::EnumerableMap{T,S,Q}) where {T,S,Q} = length(iter.source) +function Base.start(iter::EnumerableMapHasEltype{T,S,Q}) where {T,S,Q} + s = start(iter.source) + return s +end -function map(source::Enumerable, f::Function, f_expr::Expr) - TS = eltype(source) - T = Base._return_type(f, Tuple{TS,}) - S = typeof(source) - Q = typeof(f) - return EnumerableMap{T,S,Q}(source, f) +function Base.next(iter::EnumerableMapHasEltype{T,S,Q}, s) where {T,S,Q} + x = next(iter.source, s) + v = x[1] + s_new = x[2] + v_new = iter.f(v)::T + return v_new, s_new end -function Base.start(iter::EnumerableMap{T,S,Q}) where {T,S,Q} +function Base.done(iter::EnumerableMapHasEltype{T,S,Q}, state) where {T,S,Q} + return done(iter.source, state) +end + +# This is the EltypeUnknown() version + +struct EnumerableMapEltypeUnknown{S, Q<:Function} <: Enumerable + source::S + f::Q +end + +Base.iteratorsize(::Type{EnumerableMapEltypeUnknown{S,Q}}) where {S,Q} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.iteratorsize(S) + +Base.iteratoreltype(::Type{EnumerableMapEltypeUnknown{S,Q}}) where {S,Q} = Base.EltypeUnknown() + +Base.length(iter::EnumerableMapEltypeUnknown) = length(iter.source) + +function Base.start(iter::EnumerableMapEltypeUnknown) s = start(iter.source) return s end -function Base.next(iter::EnumerableMap{T,S,Q}, s) where {T,S,Q} +function Base.next(iter::EnumerableMapEltypeUnknown, s) x = next(iter.source, s) v = x[1] s_new = x[2] - v_new = iter.f(v)::T + v_new = iter.f(v) return v_new, s_new end -function Base.done(iter::EnumerableMap{T,S,Q}, state) where {T,S,Q} +function Base.done(iter::EnumerableMapEltypeUnknown, state) return done(iter.source, state) end + +# Implementation of the query operator + +function _map(source::Enumerable, f::Function, f_expr::Expr, ::Base.EltypeUnknown) + S = typeof(source) + Q = typeof(f) + return EnumerableMapEltypeUnknown{S,Q}(source, f) +end + +function _map(source::Enumerable, f::Function, f_expr::Expr, ::Base.HasEltype) + TS = eltype(source) + T = Base._return_type(f, Tuple{TS,}) + if isleaftype(T) + S = typeof(source) + Q = typeof(f) + return EnumerableMapHasEltype{T,S,Q}(source, f) + else + _map(source, f, f_expr, Base.EltypeUnknown()) + end +end + +function map(source::T, f::Function, f_expr::Expr) where {T<:Enumerable} + return _map(source, f, f_expr, Base.iteratoreltype(T)) +end + diff --git a/src/enumerable/enumerable_take.jl b/src/enumerable/enumerable_take.jl index aabe22f..b91dd7e 100644 --- a/src/enumerable/enumerable_take.jl +++ b/src/enumerable/enumerable_take.jl @@ -1,27 +1,26 @@ -struct EnumerableTake{T,S} <: Enumerable +struct EnumerableTake{S} <: Enumerable source::S n::Int end function take(source::Enumerable, n::Integer) - T = eltype(source) S = typeof(source) - return EnumerableTake{T,S}(source, Int(n)) + return EnumerableTake{S}(source, Int(n)) end -Base.iteratorsize(::Type{EnumerableTake{T,S}}) where {T,S} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.SizeUnknown() +Base.iteratorsize(::Type{EnumerableTake{S}}) where {S} = Base.iteratorsize(S) in (Base.HasLength(), Base.HasShape()) ? Base.HasLength() : Base.SizeUnknown() -Base.eltype(iter::EnumerableTake{T,S}) where {T,S} = T +Base.iteratoreltype(::Type{EnumerableTake{S}}) where {S} = Base.iteratoreltype(S) -Base.eltype(::Type{EnumerableTake{T,S}}) where {T,S} = T +Base.eltype(::Type{EnumerableTake{S}}) where {S} = eltype(S) -Base.length(iter::EnumerableTake{T,S}) where {T,S} = min(length(iter.source),iter.n) +Base.length(iter::EnumerableTake{S}) where {S} = min(length(iter.source),iter.n) -function Base.start(iter::EnumerableTake{T,S}) where {T,S} +function Base.start(iter::EnumerableTake{S}) where {S} return iter.n, start(iter.source) end -function Base.next(iter::EnumerableTake{T,S}, s) where {T,S} +function Base.next(iter::EnumerableTake{S}, s) where {S} n, source_state = s x = next(iter.source, source_state) v = x[1] @@ -29,7 +28,7 @@ function Base.next(iter::EnumerableTake{T,S}, s) where {T,S} return v, (n-1, source_new) end -function Base.done(iter::EnumerableTake{T,S}, state) where {T,S} +function Base.done(iter::EnumerableTake{S}, state) where {S} n, source_state = state return n<=0 || done(iter.source, source_state) end diff --git a/src/source_iterable.jl b/src/source_iterable.jl index 3212b17..69de559 100644 --- a/src/source_iterable.jl +++ b/src/source_iterable.jl @@ -1,35 +1,34 @@ -struct EnumerableIterable{T,S} <: Enumerable +struct EnumerableIterable{S} <: Enumerable source::S end function query(source) IteratorInterfaceExtensions.isiterable(source) || error() typed_source = IteratorInterfaceExtensions.getiterator(source) - T = eltype(typed_source) S = typeof(typed_source) - source_enumerable = EnumerableIterable{T,S}(typed_source) + source_enumerable = EnumerableIterable{S}(typed_source) return source_enumerable end -Base.iteratorsize(::Type{EnumerableIterable{T,S}}) where {T,S} = Base.iteratorsize(S) == Base.HasShape() ? Base.HasLength() : Base.iteratorsize(S) +Base.iteratorsize(::Type{EnumerableIterable{S}}) where {S} = Base.iteratorsize(S) == Base.HasShape() ? Base.HasLength() : Base.iteratorsize(S) -Base.eltype(iter::EnumerableIterable{T,S}) where {T,S} = T +Base.iteratoreltype(::Type{EnumerableIterable{S}}) where {S} = Base.iteratoreltype(S) -Base.eltype(iter::Type{EnumerableIterable{T,S}}) where {T,S} = T +Base.eltype(::Type{EnumerableIterable{S}}) where {S} = eltype(S) -Base.length(iter::EnumerableIterable{T,S}) where {T,S} = length(iter.source) +Base.length(iter::EnumerableIterable{S}) where {S} = length(iter.source) -function Base.start(iter::EnumerableIterable{T,S}) where {T,S} +function Base.start(iter::EnumerableIterable{S}) where {S} return start(iter.source) end -@inline function Base.next(iter::EnumerableIterable{T,S}, state) where {T,S} +@inline function Base.next(iter::EnumerableIterable{S}, state) where {S} return next(iter.source, state) end -function Base.done(iter::EnumerableIterable{T,S}, state) where {T,S} +function Base.done(iter::EnumerableIterable{S}, state) where {S} return done(iter.source, state) end diff --git a/test/runtests.jl b/test/runtests.jl index 84f0e6b..ea2a59e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -135,4 +135,7 @@ ntups = QueryOperators.query([@NT(a=1, b=2, c=3), @NT(a=4, b=5, c=6)]) @test show(Core.CoreSTDOUT(), enum) == nothing @test show(Core.CoreSTDOUT(), ntups) == nothing +include("test_enumerable_map.jl") +include("test_enumerable_filter.jl") + end diff --git a/test/test_enumerable_filter.jl b/test/test_enumerable_filter.jl new file mode 100644 index 0000000..d11210f --- /dev/null +++ b/test/test_enumerable_filter.jl @@ -0,0 +1,27 @@ +using QueryOperators +using Base.Test + +@testset "filter" begin + +X = [1,2,3,4] + +# Test with eltype known + +a = QueryOperators.@filter(QueryOperators.query(X), i->i%2==0) +aa = collect(a) + +@test Base.iteratoreltype(typeof(a))==Base.HasEltype() +@test Base.iteratorsize(typeof(a)) == Base.SizeUnknown() +@test aa == [2,4] + +# Test with eltype unknown + +b = QueryOperators.@filter(QueryOperators.query(i for i in X), i->i%2==0) +bb = collect(b) + +@test Base.iteratoreltype(typeof(b))==Base.EltypeUnknown() +@test Base.iteratorsize(typeof(b)) == Base.SizeUnknown() +@test bb == [2,4] +@test eltype(bb) == Int + +end diff --git a/test/test_enumerable_map.jl b/test/test_enumerable_map.jl new file mode 100644 index 0000000..f450525 --- /dev/null +++ b/test/test_enumerable_map.jl @@ -0,0 +1,40 @@ +using QueryOperators +using Base.Test + +@testset "map" begin + +X = [1,2,3,4] + +# Test with eltype known + +a = QueryOperators.@map(QueryOperators.query(X), i->i^2) +aa = collect(a) + +@test Base.iteratoreltype(typeof(a))==Base.HasEltype() +@test Base.iteratorsize(typeof(a)) == Base.HasLength() +@test length(a) == 4 +@test aa == [1,4,9,16] + +# Test with eltype unknown + +b = QueryOperators.@map(QueryOperators.query(i for i in X), i->i) +bb = collect(b) + +@test Base.iteratoreltype(typeof(b))==Base.EltypeUnknown() +@test Base.iteratorsize(typeof(b)) == Base.HasLength() +@test length(b) == 4 +@test bb == [1,2,3,4] +@test eltype(bb) == Int + +# Test with known source eltype, but inference gives up + +c = QueryOperators.@map(QueryOperators.query(X), i->i>10 ? 2 : 4.) +cc = collect(c) + +@test Base.iteratoreltype(typeof(c))==Base.EltypeUnknown() +@test Base.iteratorsize(typeof(c)) == Base.HasLength() +@test length(c) == 4 +@test cc == [4.,4.,4.,4.] +@test eltype(cc) == Float64 + +end