-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bump compat for Metalhead #232
Conversation
Failing on GPU only. The complaint is about scalar indexing. I've spent some time on this today but this is hard for me to debug because I don't currently have GPU access. I conjecture that the following code is failing on a GPU but not a CPU, and this contains the issue. Be good if someone can confirm this indeed fails. And if so, where is the scalar indexing? import Flux
import MLJFlux
import StableRNGs.StableRNG
rng = StableRNG(123)
X, y = MLJFlux.make_images(rng);
typeof(X)
# Vector{Matrix{Gray{Float64}}}
data = MLJFlux.collate(ImageClassifier(), X, y);
Flux.gpu(data) # no effect on my CPU-only machine
typeof(data)
# Tuple{Vector{Array{Float32, 4}}, Vector{OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}}
n_channels = 1
n_classes = 3
init = Flux.glorot_uniform(rng)
chain = Flux.Chain(
Flux.Conv((2, 2), n_channels=>2, init=init),
Flux.Conv((2, 2), 2=>1, init=init),
x->reshape(x, :, size(x)[end]),
Flux.Dense(16, n_classes, init=init))
x = data[1][1]
typeof(x)
# Array{Float32, 4}
sizeof(x)
# (6, 6, 1, 1)
chain(x) |
Okay. I guess the But these docs say that |
Okay, looks like the colon is not supported in the |
@ablaom would you also be open to making Metalhead an optional dependency? |
function make2d(x) | ||
l = length(x) | ||
b = size(x)[end] | ||
reshape(x, div(l, b), b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I thought this use of reshape
would avoid the "scalar indexing error", but tests are still not passing.
That's tricky because the default builder for ImageClassifier is a VGG architecture from Metalhead.jl. We could throw an error if An alternative, which already looks too complicated to me, is to make the default
Related discussion: #162 Maybe you have a better idea? |
Not a better idea but moving ImageClassifier itself out to another package is another option. Not a great option though if you want all Flux-related wrappers to be in this repo. |
That seems strange to me, because it's basically what I had a look through the failing CI runs, and the problem is instead this warning: https://buildkite.com/julialang/mljflux-dot-jl/builds/339#018a25d7-e2c7-4c2b-a7af-c1a9d97436c4/425-777. Because we switched to package extensions in Flux 0.14, cuDNN needs to be separately added to an environment to enable the CUDA conv routines in NNlib. I'm guessing MLJFlux doesn't want to take it on as a dep, so adding it into your test env/extras should be enough. |
Codecov ReportPatch coverage:
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## dev #232 +/- ##
==========================================
- Coverage 93.26% 92.06% -1.21%
==========================================
Files 12 12
Lines 312 315 +3
==========================================
- Hits 291 290 -1
- Misses 21 25 +4
☔ View full report in Codecov by Sentry. |
Thanks indeed for the help @ToucheSir. I guess that scalar indexing error was a Red Herring. I've added |
This PR bumps the [compat] for Metalhead to "0.8" and addresses resulting breakages.
Replaces #226