From 8db7e2a6f3586a9bfa95d46050caf65c28c1e726 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 19:17:39 +0000 Subject: [PATCH 01/14] Speeding up `onehotbatch` by constructing array directly --- src/onehot.jl | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 36345438a3..dc7f528fcb 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -183,9 +183,27 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl 3 6 15 3 9 3 12 3 6 15 3 ``` """ -onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...) -# NB function barier: -_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls]) +onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) + +# NB function barrier: +function _onehotbatch(data, labels) + n_labels = length(labels) + indices = _findval.(data, Ref(labels)) + if nothing in indices + unexpected_values = unique(data[indices .== -1]) + error("Values $unexpected_values are not in labels") + end + OneHotArray(indices, n_labels) +end + +function _onehotbatch(data, labels, default) + n_labels = length(labels) + indices = _findval.(data, Ref(labels)) + if nothing in indices + indices = replace!(indices, -1 => _findval(default, labels)) + end + return OneHotArray(indices, n_labels) +end """ onecold(y::AbstractArray, labels = 1:size(y,1)) From af976e06b8e391b859c88491201a4fcde2420d70 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 19:26:41 +0000 Subject: [PATCH 02/14] fixing error cases --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index dc7f528fcb..f716b13585 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -190,7 +190,7 @@ function _onehotbatch(data, labels) n_labels = length(labels) indices = _findval.(data, Ref(labels)) if nothing in indices - unexpected_values = unique(data[indices .== -1]) + unexpected_values = unique(data[indices .== nothing]) error("Values $unexpected_values are not in labels") end OneHotArray(indices, n_labels) @@ -200,7 +200,7 @@ function _onehotbatch(data, labels, default) n_labels = length(labels) indices = _findval.(data, Ref(labels)) if nothing in indices - indices = replace!(indices, -1 => _findval(default, labels)) + indices = replace!(indices, nothing => _findval(default, labels)) end return OneHotArray(indices, n_labels) end From 0e7b6080371bff019c36cfd89399e1c365329dea Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 20:00:42 +0000 Subject: [PATCH 03/14] can't replace in-place --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index f716b13585..cbe8e969cf 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -190,7 +190,7 @@ function _onehotbatch(data, labels) n_labels = length(labels) indices = _findval.(data, Ref(labels)) if nothing in indices - unexpected_values = unique(data[indices .== nothing]) + unexpected_values = unique(data[indices .== nothinghl]) error("Values $unexpected_values are not in labels") end OneHotArray(indices, n_labels) @@ -200,7 +200,7 @@ function _onehotbatch(data, labels, default) n_labels = length(labels) indices = _findval.(data, Ref(labels)) if nothing in indices - indices = replace!(indices, nothing => _findval(default, labels)) + indices = replace(indices, nothing => _findval(default, labels)) end return OneHotArray(indices, n_labels) end From 5ed74b672df1162a86567d5ab9cb8be136fc6217 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 22:03:38 +0000 Subject: [PATCH 04/14] Update src/onehot.jl Co-authored-by: Kyle Daruwalla --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index cbe8e969cf..03004393ea 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -193,7 +193,7 @@ function _onehotbatch(data, labels) unexpected_values = unique(data[indices .== nothinghl]) error("Values $unexpected_values are not in labels") end - OneHotArray(indices, n_labels) + return OneHotArray(indices, n_labels) end function _onehotbatch(data, labels, default) From 5041a27525cc6af2145777a0f7b04bf8513c15ed Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 22:03:48 +0000 Subject: [PATCH 05/14] Update src/onehot.jl Co-authored-by: Kyle Daruwalla --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 03004393ea..586d97e333 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -190,7 +190,7 @@ function _onehotbatch(data, labels) n_labels = length(labels) indices = _findval.(data, Ref(labels)) if nothing in indices - unexpected_values = unique(data[indices .== nothinghl]) + unexpected_values = unique(data[indices .== nothing]) error("Values $unexpected_values are not in labels") end return OneHotArray(indices, n_labels) From 2fd1515e664e2ad20ddaac4d707c3c943e633859 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 22:58:08 +0000 Subject: [PATCH 06/14] Update src/onehot.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 586d97e333..b668fe64d6 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -188,7 +188,7 @@ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? # NB function barrier: function _onehotbatch(data, labels) n_labels = length(labels) - indices = _findval.(data, Ref(labels)) + indices = map(x -> _findval(x, labels), data) if nothing in indices unexpected_values = unique(data[indices .== nothing]) error("Values $unexpected_values are not in labels") From d9eae397154d463e58891ff865fcc8883ea41279 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Tue, 22 Feb 2022 23:08:09 +0000 Subject: [PATCH 07/14] changed broadcasts to map; new variable name for replaced indices --- src/onehot.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index b668fe64d6..1412c7967d 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -198,11 +198,15 @@ end function _onehotbatch(data, labels, default) n_labels = length(labels) - indices = _findval.(data, Ref(labels)) + indices = map(x -> _findval(x, labels), data) if nothing in indices - indices = replace(indices, nothing => _findval(default, labels)) + default_index = _findval(default, labels) + isnothing(default_index) && error("Default value $default_index is not in labels") + replaced_indices = replace(indices, nothing => default_index) + return OneHotArray(replaced_indices, n_labels) + else + return OneHotArray(indices, n_labels) end - return OneHotArray(indices, n_labels) end """ From 421b9b33c0051894603061ed5dfefae8fed8ff74 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:41:44 +0000 Subject: [PATCH 08/14] Added a method for strings --- src/onehot.jl | 1 + test/onehot.jl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/onehot.jl b/src/onehot.jl index 1412c7967d..7bbc3a1e24 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -184,6 +184,7 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl ``` """ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) +onehotbatch(data::AbstractString, labels, default...) = onehotbatch([_ for _ in data], labels, default...) # NB function barrier: function _onehotbatch(data, labels) diff --git a/test/onehot.jl b/test/onehot.jl index 6c29d807f5..97ac223da5 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -18,10 +18,13 @@ using Test @test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1] @test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1] + @test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0] + @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c]) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e) + @test_throws Exception onehotbatch([:a, :e], (:a, :b, :c), :d) floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf) @test onecold(onehot(0.0, floats)) == 1 From d041014ca8746554aab37b33a68e67ffec04ef4b Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:57:06 +0000 Subject: [PATCH 09/14] Error should be given on default value, not default index --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 7bbc3a1e24..9e98cade75 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -202,7 +202,7 @@ function _onehotbatch(data, labels, default) indices = map(x -> _findval(x, labels), data) if nothing in indices default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default_index is not in labels") + isnothing(default_index) && error("Default value $default is not in labels") replaced_indices = replace(indices, nothing => default_index) return OneHotArray(replaced_indices, n_labels) else From eb15e859d225bc467160f833dc6868766dd5f0c8 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Wed, 23 Feb 2022 20:25:41 +0000 Subject: [PATCH 10/14] Removed string method; changed `map`s to a comprehensions --- src/onehot.jl | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 9e98cade75..15f1b2a443 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -184,30 +184,22 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl ``` """ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) -onehotbatch(data::AbstractString, labels, default...) = onehotbatch([_ for _ in data], labels, default...) # NB function barrier: function _onehotbatch(data, labels) - n_labels = length(labels) - indices = map(x -> _findval(x, labels), data) + indices = [_findval(i, labels) for i in data] if nothing in indices unexpected_values = unique(data[indices .== nothing]) error("Values $unexpected_values are not in labels") end - return OneHotArray(indices, n_labels) + return OneHotArray(indices, length(labels)) end function _onehotbatch(data, labels, default) - n_labels = length(labels) - indices = map(x -> _findval(x, labels), data) - if nothing in indices - default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default is not in labels") - replaced_indices = replace(indices, nothing => default_index) - return OneHotArray(replaced_indices, n_labels) - else - return OneHotArray(indices, n_labels) - end + default_index = _findval(default, labels) + isnothing(default_index) && error("Default value $default is not in labels") + indices = [isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data] + return OneHotArray(indices, length(labels)) end """ From 5e5dc7def74ad24a090bef517949360c06144d2f Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Wed, 23 Feb 2022 21:53:40 +0000 Subject: [PATCH 11/14] fixed doctests --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 15f1b2a443..709e64f9f7 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -187,7 +187,7 @@ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? # NB function barrier: function _onehotbatch(data, labels) - indices = [_findval(i, labels) for i in data] + indices = UInt32[_findval(i, labels) for i in data] if nothing in indices unexpected_values = unique(data[indices .== nothing]) error("Values $unexpected_values are not in labels") @@ -198,7 +198,7 @@ end function _onehotbatch(data, labels, default) default_index = _findval(default, labels) isnothing(default_index) && error("Default value $default is not in labels") - indices = [isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data] + indices = UInt32[isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data] return OneHotArray(indices, length(labels)) end From 55572ab7bce9a179a6e4aa8117b01472568e3aaf Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Thu, 24 Feb 2022 23:08:27 +0000 Subject: [PATCH 12/14] using something instead of calling findval twice; fixing case where findval would return nothing --- src/onehot.jl | 9 ++++----- test/onehot.jl | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 709e64f9f7..2645c4c577 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -185,11 +185,10 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl """ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) -# NB function barrier: function _onehotbatch(data, labels) - indices = UInt32[_findval(i, labels) for i in data] - if nothing in indices - unexpected_values = unique(data[indices .== nothing]) + indices = UInt32[something(_findval(i, labels), 0) for i in data] + if 0 in indices + unexpected_values = unique(data[indices .== 0]) error("Values $unexpected_values are not in labels") end return OneHotArray(indices, length(labels)) @@ -198,7 +197,7 @@ end function _onehotbatch(data, labels, default) default_index = _findval(default, labels) isnothing(default_index) && error("Default value $default is not in labels") - indices = UInt32[isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data] + indices = UInt32[something(_findval(i, labels), default_index) for i in data] return OneHotArray(indices, length(labels)) end diff --git a/test/onehot.jl b/test/onehot.jl index 97ac223da5..9a293da8f7 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -24,7 +24,7 @@ using Test @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e) - @test_throws Exception onehotbatch([:a, :e], (:a, :b, :c), :d) + @test_throws Exception onehotbatch([:a, :b], (:a, :b, :c), :d) floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf) @test onecold(onehot(0.0, floats)) == 1 From 6ad6e97ea4ab28fff048e5d5cb4d4491abfb58ea Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Thu, 24 Feb 2022 23:51:29 +0000 Subject: [PATCH 13/14] additional test probably isn't necessary --- test/onehot.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/onehot.jl b/test/onehot.jl index 9a293da8f7..91f64b763d 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -24,7 +24,6 @@ using Test @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e) @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e) - @test_throws Exception onehotbatch([:a, :b], (:a, :b, :c), :d) floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf) @test onecold(onehot(0.0, floats)) == 1 From 783fa7d4f109c7160025f44eec47f95b19ac7007 Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Sun, 27 Feb 2022 20:13:52 +0000 Subject: [PATCH 14/14] Make error messaging for unexpected values simpler Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/onehot.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 2645c4c577..c06799caad 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -188,8 +188,9 @@ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? function _onehotbatch(data, labels) indices = UInt32[something(_findval(i, labels), 0) for i in data] if 0 in indices - unexpected_values = unique(data[indices .== 0]) - error("Values $unexpected_values are not in labels") + for x in data + isnothing(_findval(x, labels)) && error("Value $x not found in labels") + end end return OneHotArray(indices, length(labels)) end