-
-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add axis permutedims #36
base: main
Are you sure you want to change the base?
Changes from 3 commits
69f6754
db5e5af
9a0e1a2
1f6599c
d87edb9
21bb74f
ecf70d2
a16b479
6ae60c2
3006536
e9f8398
7df5675
0b43cbb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
""" | ||
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M} | ||
OneHotArray(indices, L) | ||
OneHotArray(indices, L, [axis=1]) | ||
|
||
A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`) | ||
A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, axis) == L` and `sum(A, dims=axis) == 1`) | ||
stored as a compact `N == M-1`-dimensional array of indices. | ||
|
||
Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). | ||
|
@@ -15,6 +15,10 @@ end | |
OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L) | ||
OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) | ||
OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) | ||
function OneHotArray(indices, L, axis::Int) | ||
a = collect(1:length(size(indices))+1) | ||
PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The permutation can be computed without mutating an array, something like this:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notice your second example gives There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought I may be wrong by an invperm, but did not check carefully. |
||
end | ||
|
||
_indices(x::OneHotArray) = x.indices | ||
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = | ||
|
@@ -69,7 +73,7 @@ end | |
# the method above is faster on the CPU but will scalar index on the GPU | ||
# so we define the method below to pass the extra indices directly to GPU array | ||
function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray}, | ||
i::Int, | ||
i::Int, | ||
I::Vararg{Any, N}) where N | ||
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) | ||
return x.indices[I...] .== i | ||
|
@@ -154,5 +158,5 @@ Base.map(f, x::OneHotLike) = Base.broadcast(f, x) | |
|
||
Base.argmax(x::OneHotLike; dims = Colon()) = | ||
(_isonehot(x) && dims == 1) ? | ||
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : | ||
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : | ||
invoke(argmax, Tuple{AbstractArray}, x; dims = dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we do add this, it should probably be a
dims::Integer
keyword ononehotbatch
. IMO it's weird if a type constructor does not return the stated type.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the sentiment, but it seems awkward to have to maintain a set of functions and increase their complexity to just have the same functionality..
If this is a complete no go then either the alternative implementation (which might have even more problems) or maybe add this functionality as a separate utility function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is it more complex to alter the lower-case function than the upper-case type constructor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought some more about this point and I agree changing onehot/onehotbatch/etc is the better approach.
I'll take a closer look at the functions (still not very familiar with them all).
Which would you say would be appropriate?