diff --git a/src/onehot.jl b/src/onehot.jl index 86afd513dc..5c553db7c3 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -183,9 +183,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl 3 6 15 3 9 3 12 3 6 15 3 ``` """ -onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...) -# NB function barier: -_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls]) +onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) + +function _onehotbatch(data, labels) + indices = UInt32[something(_findval(i, labels), 0) for i in data] + if 0 in indices + for x in data + isnothing(_findval(x, labels)) && error("Value $x not found in labels") + end + end + return OneHotArray(indices, length(labels)) +end + +function _onehotbatch(data, labels, default) + default_index = _findval(default, labels) + isnothing(default_index) && error("Default value $default is not in labels") + indices = UInt32[something(_findval(i, labels), default_index) for i in data] + return OneHotArray(indices, length(labels)) +end """ onecold(y::AbstractArray, labels = 1:size(y,1)) diff --git a/test/onehot.jl b/test/onehot.jl index 6c29d807f5..91f64b763d 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -18,6 +18,8 @@ using Test @test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1] @test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1] + @test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0] + @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c]) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)