-
-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add axis permutedims #36
base: main
Are you sure you want to change the base?
Changes from 11 commits
69f6754
db5e5af
9a0e1a2
1f6599c
d87edb9
21bb74f
ecf70d2
a16b479
6ae60c2
3006536
e9f8398
7df5675
0b43cbb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -48,7 +48,7 @@ end | |||||
_findval(val, labels::Tuple{}, i::Integer) = nothing | ||||||
|
||||||
""" | ||||||
onehotbatch(xs, labels, [default]) | ||||||
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1}) | ||||||
|
||||||
Returns a [`OneHotMatrix`](@ref) where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot). | ||||||
This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the | ||||||
|
@@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`. | |||||
Note that `xs` can be any iterable, such as a string. And that using a tuple | ||||||
for `labels` will often speed up construction, certainly for less than 32 classes. | ||||||
|
||||||
If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# Examples | ||||||
```jldoctest | ||||||
julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') | ||||||
|
@@ -74,30 +76,40 @@ julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') | |||||
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ | ||||||
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ | ||||||
|
||||||
julia> oh = onehotbatch("abracadabra", 'a':'e', 'e'; dims=2) | ||||||
nomadbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool: | ||||||
1 ⋅ ⋅ ⋅ ⋅ | ||||||
⋅ 1 ⋅ ⋅ ⋅ | ||||||
⋅ ⋅ ⋅ ⋅ 1 | ||||||
1 ⋅ ⋅ ⋅ ⋅ | ||||||
⋅ ⋅ 1 ⋅ ⋅ | ||||||
1 ⋅ ⋅ ⋅ ⋅ | ||||||
⋅ ⋅ ⋅ 1 ⋅ | ||||||
1 ⋅ ⋅ ⋅ ⋅ | ||||||
⋅ 1 ⋅ ⋅ ⋅ | ||||||
⋅ ⋅ ⋅ ⋅ 1 | ||||||
1 ⋅ ⋅ ⋅ ⋅ | ||||||
|
||||||
julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently | ||||||
3×11 Matrix{Int64}: | ||||||
1 4 13 1 7 1 10 1 4 13 1 | ||||||
2 5 14 2 8 2 11 2 5 14 2 | ||||||
3 6 15 3 9 3 12 3 6 15 3 | ||||||
``` | ||||||
""" | ||||||
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My suggestion for how to write this would be this. Add the keyword dims but leave the basic path as close to untouched as you can, like so: onehotbatch(data, labels, default...; dims=Val(1)) = _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)
function _onehotbatch(::Val{1}, data, labels)
# as before
return OneHotArray(indices, length(labels))
end
function _onehotbatch(::Val{1}, data, labels, default)
# as before
return OneHotArray(indices, length(labels))
end In particular, this does not call Readers uninterested in permutations can stop there. But to handle them, make it obvious that we call the same path, and then permute it. _onehotbatch(dims::Integer, data, labels, default...) = _onehotbatch(Val(dims), data, labels, default...)
_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...))
_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array)
function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M}
# this is where you compute perm, can use N or M, I forget...
PermutedDimsArray(array, perm)
end I made a special case for |
||||||
|
||||||
function _onehotbatch(data, labels) | ||||||
indices = UInt32[something(_findval(i, labels), 0) for i in data] | ||||||
if 0 in indices | ||||||
for x in data | ||||||
isnothing(_findval(x, labels)) && error("Value $x not found in labels") | ||||||
end | ||||||
onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) | ||||||
onehotbatch(data::AbstractRange, labels, default...; dims::Val{D} = Val(1)) where D = onehotbatch(collect(data), labels, default...; dims=dims) | ||||||
function onehotbatch(data::AbstractArray{<:Any, N}, labels, default...; dims::Val{D}= Val(1)) where {N,D} | ||||||
out = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) | ||||||
if D==1 | ||||||
out | ||||||
else | ||||||
perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? D : d), N+1)) | ||||||
# need to use obtuse PermutedDimsArray constructor in order to stabilise permuation types | ||||||
iperm = invperm(perm) | ||||||
PermutedDimsArray{eltype(out),N+1,(perm...,),(iperm...,),typeof(out)}(out) | ||||||
end | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
function _onehotbatch(data, labels, default) | ||||||
default_index = _findval(default, labels) | ||||||
isnothing(default_index) && error("Default value $default is not in labels") | ||||||
indices = UInt32[something(_findval(i, labels), default_index) for i in data] | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) | ||||||
|
@@ -108,6 +120,8 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{< | |||||
indices = UInt32.(data .+ offset) | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels) | ||||||
|
||||||
# That bounds check with extrema synchronises on GPU, much slower than rest of the function, | ||||||
# hence add a special method, with a less helpful error message: | ||||||
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) | ||||||
|
@@ -120,6 +134,24 @@ function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRang | |||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
|
||||||
function _onehotbatch(data, labels) | ||||||
indices = UInt32[something(_findval(i, labels), 0) for i in data] | ||||||
if 0 in indices | ||||||
for x in data | ||||||
isnothing(_findval(x, labels)) && error("Value $x not found in labels") | ||||||
end | ||||||
end | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
function _onehotbatch(data, labels, default) | ||||||
default_index = _findval(default, labels) | ||||||
isnothing(default_index) && error("Default value $default is not in labels") | ||||||
indices = UInt32[something(_findval(i, labels), default_index) for i in data] | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
""" | ||||||
onecold(y::AbstractArray, labels = 1:size(y,1)) | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.