Skip to content

Commit

Permalink
Merge pull request #1888 from TLipede/speed-onehot
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack authored Mar 4, 2022
2 parents dc479e3 + 783fa7d commit 0b3e8c5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
3 6 15 3 9 3 12 3 6 15 3
```
"""
onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
# NB function barier:
_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)

function _onehotbatch(data, labels)
indices = UInt32[something(_findval(i, labels), 0) for i in data]
if 0 in indices
for x in data
isnothing(_findval(x, labels)) && error("Value $x not found in labels")
end
end
return OneHotArray(indices, length(labels))
end

function _onehotbatch(data, labels, default)
default_index = _findval(default, labels)
isnothing(default_index) && error("Default value $default is not in labels")
indices = UInt32[something(_findval(i, labels), default_index) for i in data]
return OneHotArray(indices, length(labels))
end

"""
onecold(y::AbstractArray, labels = 1:size(y,1))
Expand Down
2 changes: 2 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using Test
@test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1]
@test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1]

@test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0]

@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c])
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c))
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)
Expand Down

0 comments on commit 0b3e8c5

Please sign in to comment.