From 69f67544b430e0cd71ac9c1325b604818d654b2e Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Wed, 24 May 2023 17:23:36 +0300 Subject: [PATCH 01/13] added new constructor --- src/array.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/array.jl b/src/array.jl index 6350eb9..fb1e818 100644 --- a/src/array.jl +++ b/src/array.jl @@ -15,6 +15,7 @@ end OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L) OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) +OneHotArray(indices, L, axis::Int) = PermutedDimsArray(OneHotArray(indices, L), [axis, 1]) _indices(x::OneHotArray) = x.indices _indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = @@ -69,7 +70,7 @@ end # the method above is faster on the CPU but will scalar index on the GPU # so we define the method below to pass the extra indices directly to GPU array function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray}, - i::Int, + i::Int, I::Vararg{Any, N}) where N @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) return x.indices[I...] .== i @@ -154,5 +155,5 @@ Base.map(f, x::OneHotLike) = Base.broadcast(f, x) Base.argmax(x::OneHotLike; dims = Colon()) = (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : + reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : invoke(argmax, Tuple{AbstractArray}, x; dims = dims) From db5e5afe464018f146ced24f113dde387f09147c Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Wed, 24 May 2023 19:10:03 +0300 Subject: [PATCH 02/13] fix for more than 2 dimensions --- src/array.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/array.jl b/src/array.jl index fb1e818..1b59016 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,8 +1,8 @@ """ OneHotArray{T, N, M, I} <: AbstractArray{Bool, M} - OneHotArray(indices, L) + OneHotArray(indices, L, [axis=1]) -A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`) +A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, axis) == L` and `sum(A, dims=axis) == 1`) stored as a compact `N == M-1`-dimensional array of indices. Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). 
@@ -15,7 +15,10 @@ end OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L) OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) -OneHotArray(indices, L, axis::Int) = PermutedDimsArray(OneHotArray(indices, L), [axis, 1]) +function OneHotArray(indices, L, axis::Int) + a = collect(1:length(size(indices))+1) + PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis))) +end _indices(x::OneHotArray) = x.indices _indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = From 9a0e1a220a36e2371abe2d9783333229110794ed Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Thu, 25 May 2023 01:34:56 +0300 Subject: [PATCH 03/13] passing tests --- test/array.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/array.jl b/test/array.jl index 70247a1..a24698d 100644 --- a/test/array.jl +++ b/test/array.jl @@ -3,12 +3,14 @@ ov2 = OneHotVector(rand(1:11), 11) om = OneHotMatrix(rand(1:10, 5), 10) om2 = OneHotMatrix(rand(1:11, 5), 11) oa = OneHotArray(rand(1:10, 5, 5), 10) +oa2 = OneHotArray(rand(1:10, 5, 5), 10, 2) # sizes @testset "Base.size" begin @test size(ov) == (10,) @test size(om) == (10, 5) @test size(oa) == (10, 5, 5) + @test size(oa2) == (5, 10, 5) end @testset "Indexing" begin @@ -32,18 +34,30 @@ end @test oa[:, :, :] == oa @test oa[:] == reshape(oa, :) + @test oa2[3, 3, 3] == (oa2.parent.indices[3, 3] == 3) + @test oa2[3, :, 3] == OneHotVector(oa2.parent.indices[3, 3], 10) + @test oa2[:, 3, 3] == (oa2.parent.indices[:, 3] .== 3) + @test oa2[:, 3, :] == (oa2.parent.indices .== 3) + @test oa2[3, :, :] == OneHotMatrix(oa2.parent.indices[3, :], 10) + @test oa2[:, :, :] == oa2 + @test oa2[:] == reshape(oa2, :) + # cartesian indexing @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3] + @test oa2[CartesianIndex(3, 3, 3)] == oa2[3, 3, 3] # linear indexing @test om[11] == om[1, 2] @test oa[52] == oa[2, 1, 2] + @test oa2[55] == oa2[1, 2, 2] # bounds checks @test_throws BoundsError ov[0] @test_throws BoundsError om[2, -1] @test_throws BoundsError oa[11, 5, 5] @test_throws BoundsError oa[:, :] + @test_throws BoundsError oa2[5, 11, 5] + @test_throws BoundsError oa2[:, :] end @testset "Concatenating" begin @@ -64,6 +78,9 @@ end @test cat(oa, oa; dims = 3) isa OneHotArray @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) + @test cat(oa2, oa2; dims = 3) == OneHotArray(cat(oa2.parent.indices, oa2.parent.indices; dims = 2), 10, 2) + @test cat(oa2, oa2; dims = 2) == cat(collect(oa2), collect(oa2); dims = 2) + # stack @test stack([ov, ov]) == hcat(ov, ov) @test stack([ov, ov, ov]) isa OneHotMatrix @@ -96,6 +113,18 @@ end @test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10)) @test OneHotArrays._fast_argmax(r) == collect(reshape(oa.indices, :)) end + + @testset "w/ cat" begin + r = reshape(oa2, 10, :) + @test vcat(r, r) isa Array{Bool} + end + + @testset "w/ argmax" begin + oa2p = PermutedDimsArray(oa2, [2,1,3]) + r = reshape(oa2p, 10, :) + @test argmax(r) == argmax(OneHotMatrix(reshape(oa2p.parent.parent.indices, :), 10)) + @test stack(collect(Tuple.(OneHotArrays._fast_argmax(r))))[1,:] == collect(reshape(oa2p.parent.parent.indices, :)) + end end @testset "Base.argmax" begin @@ -106,9 +135,13 @@ end @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2) @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1) @test 
argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) + @test argmax(oa2; dims = 2) == argmax(convert(Array{Bool}, oa2); dims = 2) + @test argmax(oa2; dims = 3) == argmax(convert(Array{Bool}, oa2); dims = 3) end @testset "Forward map to broadcast" begin @test map(identity, oa) == oa @test map(x -> 2 * x, oa) == 2 .* oa + @test map(identity, oa2) == oa2 + @test map(x -> 2 * x, oa2) == 2 .* oa2 end From 1f6599ce506c8274ec508350674917050661a9e4 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Thu, 25 May 2023 18:23:52 +0300 Subject: [PATCH 04/13] fixed test (2) --- test/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/array.jl b/test/array.jl index a24698d..ab56af4 100644 --- a/test/array.jl +++ b/test/array.jl @@ -49,7 +49,7 @@ end # linear indexing @test om[11] == om[1, 2] @test oa[52] == oa[2, 1, 2] - @test oa2[55] == oa2[1, 2, 2] + @test oa2[56] == oa2[1, 2, 2] # bounds checks @test_throws BoundsError ov[0] From d87edb9833b95bcd006a66c18a7655a606c80471 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Thu, 25 May 2023 19:02:24 +0300 Subject: [PATCH 05/13] clearer permutation --- src/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array.jl b/src/array.jl index 1b59016..2b50a4e 100644 --- a/src/array.jl +++ b/src/array.jl @@ -16,8 +16,8 @@ OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) function OneHotArray(indices, L, axis::Int) - a = collect(1:length(size(indices))+1) - PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis))) + perm = ntuple(d -> (d==axis ? 1 : (d==1 ? axis : d)), length(size(indices))+1) + PermutedDimsArray(OneHotArray(indices, L), perm) end _indices(x::OneHotArray) = x.indices From 21bb74f671c0ef48fd5834472752dd6f69aa920a Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 26 May 2023 09:49:01 +0300 Subject: [PATCH 06/13] moved functionality to onehotbatch and updated tests --- src/array.jl | 4 ---- src/onehot.jl | 12 ++++++++++++ test/array.jl | 33 --------------------------------- test/onehot.jl | 10 ++++++++++ 4 files changed, 22 insertions(+), 37 deletions(-) diff --git a/src/array.jl b/src/array.jl index 2b50a4e..e62f2ba 100644 --- a/src/array.jl +++ b/src/array.jl @@ -15,10 +15,6 @@ end OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L) OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) -function OneHotArray(indices, L, axis::Int) - perm = ntuple(d -> (d==axis ? 1 : (d==1 ? axis : d)), length(size(indices))+1) - PermutedDimsArray(OneHotArray(indices, L), perm) -end _indices(x::OneHotArray) = x.indices _indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = diff --git a/src/onehot.jl b/src/onehot.jl index d2d5e9d..0aa1752 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -81,6 +81,18 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl 3 6 15 3 9 3 12 3 6 15 3 ``` """ + +function onehotbatch(args...; dims::Integer) + out = onehotbatch(args...) + if dims==1 + out + else + indices = args[1] + perm = ntuple(d -> d==dims ? 1 : (d==1 ? 
dims : d), length(size(indices))+1) + PermutedDimsArray(out, perm) + end +end + onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) function _onehotbatch(data, labels) diff --git a/test/array.jl b/test/array.jl index ab56af4..70247a1 100644 --- a/test/array.jl +++ b/test/array.jl @@ -3,14 +3,12 @@ ov2 = OneHotVector(rand(1:11), 11) om = OneHotMatrix(rand(1:10, 5), 10) om2 = OneHotMatrix(rand(1:11, 5), 11) oa = OneHotArray(rand(1:10, 5, 5), 10) -oa2 = OneHotArray(rand(1:10, 5, 5), 10, 2) # sizes @testset "Base.size" begin @test size(ov) == (10,) @test size(om) == (10, 5) @test size(oa) == (10, 5, 5) - @test size(oa2) == (5, 10, 5) end @testset "Indexing" begin @@ -34,30 +32,18 @@ end @test oa[:, :, :] == oa @test oa[:] == reshape(oa, :) - @test oa2[3, 3, 3] == (oa2.parent.indices[3, 3] == 3) - @test oa2[3, :, 3] == OneHotVector(oa2.parent.indices[3, 3], 10) - @test oa2[:, 3, 3] == (oa2.parent.indices[:, 3] .== 3) - @test oa2[:, 3, :] == (oa2.parent.indices .== 3) - @test oa2[3, :, :] == OneHotMatrix(oa2.parent.indices[3, :], 10) - @test oa2[:, :, :] == oa2 - @test oa2[:] == reshape(oa2, :) - # cartesian indexing @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3] - @test oa2[CartesianIndex(3, 3, 3)] == oa2[3, 3, 3] # linear indexing @test om[11] == om[1, 2] @test oa[52] == oa[2, 1, 2] - @test oa2[56] == oa2[1, 2, 2] # bounds checks @test_throws BoundsError ov[0] @test_throws BoundsError om[2, -1] @test_throws BoundsError oa[11, 5, 5] @test_throws BoundsError oa[:, :] - @test_throws BoundsError oa2[5, 11, 5] - @test_throws BoundsError oa2[:, :] end @testset "Concatenating" begin @@ -78,9 +64,6 @@ end @test cat(oa, oa; dims = 3) isa OneHotArray @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) - @test cat(oa2, oa2; dims = 3) == OneHotArray(cat(oa2.parent.indices, oa2.parent.indices; dims = 2), 10, 2) - @test cat(oa2, oa2; dims = 2) == cat(collect(oa2), collect(oa2); dims = 2) - # stack @test stack([ov, ov]) == hcat(ov, ov) @test stack([ov, ov, ov]) isa OneHotMatrix @@ -113,18 +96,6 @@ end @test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10)) @test OneHotArrays._fast_argmax(r) == collect(reshape(oa.indices, :)) end - - @testset "w/ cat" begin - r = reshape(oa2, 10, :) - @test vcat(r, r) isa Array{Bool} - end - - @testset "w/ argmax" begin - oa2p = PermutedDimsArray(oa2, [2,1,3]) - r = reshape(oa2p, 10, :) - @test argmax(r) == argmax(OneHotMatrix(reshape(oa2p.parent.parent.indices, :), 10)) - @test stack(collect(Tuple.(OneHotArrays._fast_argmax(r))))[1,:] == collect(reshape(oa2p.parent.parent.indices, :)) - end end @testset "Base.argmax" begin @@ -135,13 +106,9 @@ end @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2) @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1) @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) - @test argmax(oa2; dims = 2) == argmax(convert(Array{Bool}, oa2); dims = 2) - @test argmax(oa2; dims = 3) == argmax(convert(Array{Bool}, oa2); dims = 3) end @testset "Forward map to broadcast" begin @test map(identity, oa) == oa @test map(x -> 2 * x, oa) == 2 .* oa - @test map(identity, oa2) == oa2 - @test map(x -> 2 * x, oa2) == 2 .* oa2 end diff --git a/test/onehot.jl b/test/onehot.jl index fffac19..c4ed3fd 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -16,6 +16,7 @@ @test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1] @test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0] + 
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c]) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) @@ -69,3 +70,12 @@ end @test y[:,1] isa OneHotVector @test y[:,:] isa OneHotMatrix end + +@testset "onehotbatch dims" begin + # basic tests + @test onehotbatch([20, 10], 10:10:30; dims=2) == Bool[0 1 0; 1 0 0] + @test onehotbatch([10, 20], [30, 40, 50], 30; dims=2) == Bool[1 0 0; 1 0 0] + # higher dimensions + @test size(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2)) == (3, 12, 4) # test shape + @test sum(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2), dims=2)[:] == ones(12) # test onehot on the second dim +end \ No newline at end of file From ecf70d2339628108af7346d05aaeee33cdeafc2d Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 26 May 2023 09:52:48 +0300 Subject: [PATCH 07/13] revert some minor changes --- src/array.jl | 8 ++++---- test/onehot.jl | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/array.jl b/src/array.jl index e62f2ba..b455f76 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,8 +1,8 @@ """ OneHotArray{T, N, M, I} <: AbstractArray{Bool, M} - OneHotArray(indices, L, [axis=1]) + OneHotArray(indices, L) -A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, axis) == L` and `sum(A, dims=axis) == 1`) +A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`) stored as a compact `N == M-1`-dimensional array of indices. Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). @@ -69,7 +69,7 @@ end # the method above is faster on the CPU but will scalar index on the GPU # so we define the method below to pass the extra indices directly to GPU array function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray}, - i::Int, + i::Int, I::Vararg{Any, N}) where N @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) return x.indices[I...] .== i @@ -154,5 +154,5 @@ Base.map(f, x::OneHotLike) = Base.broadcast(f, x) Base.argmax(x::OneHotLike; dims = Colon()) = (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : + reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : invoke(argmax, Tuple{AbstractArray}, x; dims = dims) diff --git a/test/onehot.jl b/test/onehot.jl index c4ed3fd..924322b 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -16,7 +16,6 @@ @test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1] @test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0] - @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c]) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) From a16b47983c200e8fb4587fdb0540c528b3f464d8 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 26 May 2023 09:53:38 +0300 Subject: [PATCH 08/13] revert minor change --- src/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array.jl b/src/array.jl index b455f76..6350eb9 100644 --- a/src/array.jl +++ b/src/array.jl @@ -69,7 +69,7 @@ end # the method above is faster on the CPU but will scalar index on the GPU # so we define the method below to pass the extra indices directly to GPU array function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray}, - i::Int, + i::Int, I::Vararg{Any, N}) where N @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) return x.indices[I...] 
.== i From 6ae60c2fcc313c5db6f8363c22b67d06d625fea9 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 26 May 2023 10:25:19 +0300 Subject: [PATCH 09/13] added documentation, added a failing test --- src/onehot.jl | 18 +++++++++++++++++- test/onehot.jl | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 0aa1752..71c5353 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -48,7 +48,7 @@ end _findval(val, labels::Tuple{}, i::Integer) = nothing """ - onehotbatch(xs, labels, [default]) + onehotbatch(xs, labels, [default]; dims::Integer=1) Returns a [`OneHotMatrix`](@ref) where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot). This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the @@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`. Note that `xs` can be any iterable, such as a string. And that using a tuple for `labels` will often speed up construction, certainly for less than 32 classes. +If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one. + # Examples ```jldoctest julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') @@ -74,6 +76,20 @@ julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ + julia> oh = onehotbatch("abracadabra", 'a':'e', 'e'; dims=2) +5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool: + 1 ⋅ ⋅ ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ ⋅ 1 + 1 ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ⋅ ⋅ + 1 ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + 1 ⋅ ⋅ ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ ⋅ 1 + 1 ⋅ ⋅ ⋅ ⋅ + julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently 3×11 Matrix{Int64}: 1 4 13 1 7 1 10 1 4 13 1 diff --git a/test/onehot.jl b/test/onehot.jl index 924322b..5dfaad6 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -77,4 +77,6 @@ end # higher dimensions @test size(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2)) == (3, 12, 4) # test shape @test sum(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2), dims=2)[:] == ones(12) # test onehot on the second dim + # works with strings + @test onehotbatch("ba", 'a':'c'; dims=2) == Bool[0 1 0; 1 0 0] end \ No newline at end of file From 30065367fd4202466abe613a8b45c2031e43ac77 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 26 May 2023 12:22:55 +0300 Subject: [PATCH 10/13] need to solve type inference problem --- src/onehot.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 71c5353..252f4cf 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -98,19 +98,18 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl ``` """ -function onehotbatch(args...; dims::Integer) - out = onehotbatch(args...) +onehotbatch(data::String, labels, default...; dims::Integer=1) = onehotbatch(collect(data), labels, default...; dims=dims) +onehotbatch(data::AbstractRange, labels, default...; dims::Integer=1) = onehotbatch(collect(data), labels, default...; dims=dims) +function onehotbatch(data::AbstractArray, labels, default...; dims::Integer=1) + out = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) if dims==1 out else - indices = args[1] - perm = ntuple(d -> d==dims ? 1 : (d==1 ? dims : d), length(size(indices))+1) + perm = ntuple(d -> d==dims ? 1 : (d==1 ? dims : d), length(size(data))+1) PermutedDimsArray(out, perm) end end -onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) 
- function _onehotbatch(data, labels) indices = UInt32[something(_findval(i, labels), 0) for i in data] if 0 in indices From e9f8398b548fc013269ec4d9b20382dfd0da81e4 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Wed, 26 Jul 2023 11:24:46 +0300 Subject: [PATCH 11/13] fixed type stability --- src/onehot.jl | 53 +++++++++++++++++++++++++++----------------------- test/onehot.jl | 13 ++++++++----- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 252f4cf..1978dee 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -48,7 +48,7 @@ end _findval(val, labels::Tuple{}, i::Integer) = nothing """ - onehotbatch(xs, labels, [default]; dims::Integer=1) + onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1}) Returns a [`OneHotMatrix`](@ref) where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot). This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the @@ -98,35 +98,20 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl ``` """ -onehotbatch(data::String, labels, default...; dims::Integer=1) = onehotbatch(collect(data), labels, default...; dims=dims) -onehotbatch(data::AbstractRange, labels, default...; dims::Integer=1) = onehotbatch(collect(data), labels, default...; dims=dims) -function onehotbatch(data::AbstractArray, labels, default...; dims::Integer=1) +onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) +onehotbatch(data::AbstractRange, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) +function onehotbatch(data::AbstractArray{<:Any, N}, labels, default...; dims::Val{D}= Val(1)) where {N,D} out = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) - if dims==1 + if D==1 out else - perm = ntuple(d -> d==dims ? 1 : (d==1 ? dims : d), length(size(data))+1) - PermutedDimsArray(out, perm) + perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? 
D : d), N+1)) + # need to use obtuse PermutedDimsArray constructor in order to stabilise permuation types + iperm = invperm(perm) + PermutedDimsArray{eltype(out),N+1,(perm...,),(iperm...,),typeof(out)}(out) end end -function _onehotbatch(data, labels) - indices = UInt32[something(_findval(i, labels), 0) for i in data] - if 0 in indices - for x in data - isnothing(_findval(x, labels)) && error("Value $x not found in labels") - end - end - return OneHotArray(indices, length(labels)) -end - -function _onehotbatch(data, labels, default) - default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default is not in labels") - indices = UInt32[something(_findval(i, labels), default_index) for i in data] - return OneHotArray(indices, length(labels)) -end - function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) lo, hi = extrema(data) lo < first(labels) && error("Value $lo not found in labels") @@ -135,6 +120,8 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{< indices = UInt32.(data .+ offset) return OneHotArray(indices, length(labels)) end +onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels) + # That bounds check with extrema synchronises on GPU, much slower than rest of the function, # hence add a special method, with a less helpful error message: function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) @@ -147,6 +134,24 @@ function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRang return OneHotArray(indices, length(labels)) end + +function _onehotbatch(data, labels) + indices = UInt32[something(_findval(i, labels), 0) for i in data] + if 0 in indices + for x in data + isnothing(_findval(x, labels)) && error("Value $x not found in labels") + end + end + return OneHotArray(indices, length(labels)) +end + +function _onehotbatch(data, labels, default) + default_index = _findval(default, labels) + isnothing(default_index) && error("Default value $default is not in labels") + indices = UInt32[something(_findval(i, labels), default_index) for i in data] + return OneHotArray(indices, length(labels)) +end + """ onecold(y::AbstractArray, labels = 1:size(y,1)) diff --git a/test/onehot.jl b/test/onehot.jl index 5dfaad6..27c142c 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -72,11 +72,14 @@ end @testset "onehotbatch dims" begin # basic tests - @test onehotbatch([20, 10], 10:10:30; dims=2) == Bool[0 1 0; 1 0 0] - @test onehotbatch([10, 20], [30, 40, 50], 30; dims=2) == Bool[1 0 0; 1 0 0] + @test onehotbatch([20, 10], 10:10:30; dims=Val(2)) == Bool[0 1 0; 1 0 0] + @test onehotbatch([10, 20], [30, 40, 50], 30; dims=Val(2)) == Bool[1 0 0; 1 0 0] # higher dimensions - @test size(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2)) == (3, 12, 4) # test shape - @test sum(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=2), dims=2)[:] == ones(12) # test onehot on the second dim + @test size(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=Val(2))) == (3, 12, 4) # test shape + @test sum(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=Val(2)), dims=2)[:] == ones(12) # test onehot on the second dim # works with strings - @test onehotbatch("ba", 'a':'c'; dims=2) == Bool[0 1 0; 1 0 0] + @test onehotbatch("ba", 'a':'c'; dims=Val(2)) == Bool[0 1 0; 1 0 0] + + @test @inferred(onehotbatch([20, 10], 10:10:30; dims=Val(2))) == Bool[0 1 0; 1 0 0] + @test 
@inferred(onehotbatch([40, 10], (10,20,30), 20; dims=Val(2))) == Bool[0 1 0; 1 0 0] end \ No newline at end of file From 7df567507ebaf01febe8b1bdd32985b7c5f495bd Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Wed, 26 Jul 2023 15:42:20 +0300 Subject: [PATCH 12/13] doc fix --- src/onehot.jl | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 1978dee..28ee322 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -76,25 +76,18 @@ julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ - julia> oh = onehotbatch("abracadabra", 'a':'e', 'e'; dims=2) -5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool: - 1 ⋅ ⋅ ⋅ ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ ⋅ 1 - 1 ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ 1 ⋅ ⋅ - 1 ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - 1 ⋅ ⋅ ⋅ ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ ⋅ 1 - 1 ⋅ ⋅ ⋅ ⋅ - julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently 3×11 Matrix{Int64}: 1 4 13 1 7 1 10 1 4 13 1 2 5 14 2 8 2 11 2 5 14 2 3 6 15 3 9 3 12 3 6 15 3 + +# One hot vectors on the second axis +julia> onehotbatch([0, 0, 7], 0:9; dims=Val(2)) +3×10 PermutedDimsArray(OneHotMatrix(::Vector{UInt32}), (2, 1)) with eltype Bool: + 1 0 0 0 0 0 0 0 0 0 + 1 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 1 0 0 ``` """ From 0b43cbb738790b7adb8d8207b8e696fed323e1ca Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Wed, 27 Dec 2023 17:28:17 +0200 Subject: [PATCH 13/13] save work for future --- src/onehot.jl | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 28ee322..dc58bdd 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -91,44 +91,50 @@ julia> onehotbatch([0, 0, 7], 0:9; dims=Val(2)) ``` """ -onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) -onehotbatch(data::AbstractRange, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) -function onehotbatch(data::AbstractArray{<:Any, N}, labels, default...; dims::Val{D}= Val(1)) where {N,D} - out = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) - if D==1 - out - else - perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? D : d), N+1)) - # need to use obtuse PermutedDimsArray constructor in order to stabilise permuation types - iperm = invperm(perm) - PermutedDimsArray{eltype(out),N+1,(perm...,),(iperm...,),typeof(out)}(out) - end +# developer note: +# onehotbatch is intended as the api and includes bounds checks +# _onehotbatch is intended as the implementation which includes membership checks +# _onehotbatch_fast same as above but without membership checks which would be slow on GPU + +function onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D + _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...) end -function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) +function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D lo, hi = extrema(data) lo < first(labels) && error("Value $lo not found in labels") hi > last(labels) && error("Value $hi not found in labels") offset = 1 - first(labels) indices = UInt32.(data .+ offset) - return OneHotArray(indices, length(labels)) + _onehotbatch(dims, indices, length(labels) < 32 ? 
Tuple(labels) : labels) end -onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels) # That bounds check with extrema synchronises on GPU, much slower than rest of the function, # hence add a special method, with a less helpful error message: -function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) +function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D offset = 1 - first(labels) indices = map(data) do datum i = UInt32(datum + offset) checkbounds(labels, i) i end - return OneHotArray(indices, length(labels)) + _onehotbatch_fast(dims, indices, length(labels) < 32 ? Tuple(labels) : labels) +end +# _onehotbatch_fast does not have the bounds checks in _onehotbatch which would slow down GPU, but allows permute +_onehotbatch_fast(dims::Val{D}, indices, labels) where D = _permute(dims, _onehotbatch_fast(Val(1), indices, labels)) +_onehotbatch_fast(::Val{1}, indices, labels) = OneHotArray(indices, length(labels)) + +_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...)) + +_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array) +function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M} + perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? D : d), M)) + # need to use obtuse PermutedDimsArray constructor in order to stabilise permuation types + iperm = invperm(perm) + PermutedDimsArray{eltype(out),M,(perm...,),(iperm...,),typeof(out)}(out) end - -function _onehotbatch(data, labels) +function _onehotbatch(::Val{1}, data, labels) indices = UInt32[something(_findval(i, labels), 0) for i in data] if 0 in indices for x in data @@ -138,7 +144,7 @@ function _onehotbatch(data, labels) return OneHotArray(indices, length(labels)) end -function _onehotbatch(data, labels, default) +function _onehotbatch(::Val{1}, data, labels, default) default_index = _findval(default, labels) isnothing(default_index) && error("Default value $default is not in labels") indices = UInt32[something(_findval(i, labels), default_index) for i in data]
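Taken together, the series adds a `dims` keyword to `onehotbatch` that places the one-hot vectors along a chosen axis instead of the first one. A minimal usage sketch, distilled from the tests added to `test/onehot.jl` in the later patches; it assumes the `Val`-based keyword of those commits (the released `onehotbatch` has no `dims` keyword, and the final "save work for future" commit is still being reworked):

```julia
using OneHotArrays  # assumes a build that includes this PR's `dims` keyword

# Default behaviour: labels along dimension 1 (a 3×2 OneHotMatrix).
onehotbatch([20, 10], 10:10:30)               # Bool[0 1; 1 0; 0 0]

# With dims=Val(2) the one-hot vectors lie along the second axis instead,
# returned as a lazy PermutedDimsArray over the same compact index storage.
onehotbatch([20, 10], 10:10:30; dims=Val(2))  # Bool[0 1 0; 1 0 0]

# Higher-dimensional input: a 3×4 index array with 12 labels becomes 3×12×4,
# with exactly one hot entry along dims=2.
x = onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=Val(2))
size(x)                     # (3, 12, 4)
all(sum(x; dims=2) .== 1)   # true
```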
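On the released package (without this PR), the same layout can be obtained by wrapping the usual result in a `PermutedDimsArray`, which is essentially what the first commits of the series did internally. A sketch using only the existing API:

```julia
using OneHotArrays

# 2-D case: a lazy, copy-free transpose of the OneHotMatrix.
oh  = onehotbatch([20, 10], 10:10:30)     # 3×2, labels along dimension 1
oh2 = PermutedDimsArray(oh, (2, 1))       # 2×3 view, labels along dimension 2

# N-D case: move the label axis from 1 to `axis` with the same permutation the
# PR uses (dimension `axis` gets the labels, dimension 1 gets what was at
# `axis`, every other dimension stays put).
axis = 2
inds = reshape(collect(1:12), 3, 4)
oh   = onehotbatch(inds, 1:12)            # 12×3×4
perm = ntuple(d -> d == axis ? 1 : (d == 1 ? axis : d), ndims(oh))
oh3  = PermutedDimsArray(oh, perm)        # 3×12×4 view, no copy
```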
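As committed, the generic `_permute` fallback in the final "save work for future" patch cannot run: its type parameter is spelled `d` while the body compares against `D`, and it builds the result from `out`, a name that does not exist in that scope (the argument is `array`). Below is one possible completion, keeping the explicit type-parameterised `PermutedDimsArray` constructor that the "fixed type stability" patch introduced for inference. This is a sketch of one way to finish the WIP method, not the author's final code:

```julia
# Package-internal sketch: `OneHotArray` is the type defined in src/array.jl.
# M is the one-hot array's dimensionality, D the axis that should hold the labels.
function _permute(::Val{D}, array::OneHotArray{<:Any, N, M}) where {D, N, M}
    # Swap dimensions 1 and D, leave the rest in place; e.g. D = 2, M = 3 gives (2, 1, 3).
    perm = ntuple(d -> d == D ? 1 : (d == 1 ? D : d), M)
    iperm = invperm(perm)
    # Spelling out the type parameters (rather than calling PermutedDimsArray(array, perm))
    # keeps the permutation in the type domain, as in the "fixed type stability" patch.
    return PermutedDimsArray{eltype(array), M, (perm...,), (iperm...,), typeof(array)}(array)
end
```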