GPU evaluation of Recurrence() broken on Metal #473
Can you also post the stack trace?
Ah I see, it is Metal not handling wrappers properly. CUDA is quite good at this, as it doesn't return a wrapper type if the storage is contiguous. For AMD I had to patch it at one point; I will see if that can be handled similarly for Metal.
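The wrapper issue described above can be seen without any GPU at all. The sketch below (pure Base Julia; the GPU-specific behaviour is only described in comments) shows the same parent array through two views, one contiguous and one strided. A backend like CUDA.jl can hand the contiguous one to the device as a plain buffer, while a backend that dispatches on the wrapper type may fail on either.

```julia
# Pure Base Julia sketch (no GPU needed): one contiguous view and one
# strided view over the same parent array.
A = rand(Float32, 4, 4, 4)

v = view(A, :, :, 1)   # slice along the last dim: memory stays contiguous
w = view(A, :, 1, :)   # slice along the middle dim: strided, NOT contiguous

Base.iscontiguous(v)   # can be treated as a plain buffer
Base.iscontiguous(w)   # needs real wrapper support (or a copy)
```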
Yes, I confirm it works with CUDA. Changing the issue title to reflect this.
Changed the title: Recurrence() broken? → Recurrence() broken on Metal
As a temporary workaround, you could pass in a vector of matrices instead of the 3D array. To solve this problem at its core, either:
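To see why the vector-of-matrices workaround can help, note that slicing a 3D array produces SubArray views wrapping the parent, whereas building the vector explicitly yields plain, independently stored matrices. A minimal sketch in pure Base Julia (the GPU-specific consequences are only described in comments):

```julia
# Slicing a 3D array yields SubArray views that wrap the parent array;
# an explicit comprehension copies each slice into a plain Matrix.
data = rand(Float32, 10, 10, 10)

views  = collect(eachslice(data; dims = 2))      # Vector of SubArray views
copies = [data[:, i, :] for i in axes(data, 2)]  # Vector of plain matrices

eltype(views) <: SubArray           # still wrappers over `data`
eltype(copies) == Matrix{Float32}   # no wrapper for the backend to trip on
views[1] == copies[1]               # identical values either way
```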
You mean like this?

```julia
data2 = [data[:, i, :] for i ∈ 1:10]
Lux.apply(m, gpu_device()(data2), gpu_device()(ps), gpu_device()(st))
```

On CPU they are equivalent. On Metal, it still doesn't work:
As for the upstream issue, it's worth opening one on the upstream repo.
Try:

```julia
data2 = [data[:, i, :] for i ∈ 1:10]
data2 = data2 .|> gpu_device()
Lux.apply(m, data2, gpu_device()(ps), gpu_device()(st))
```

This still gives:
What if we use Metal APIs for RNG and data?

```julia
m = Recurrence(RNNCell(10 => 10))
ps, st = Lux.setup(Metal.gpuarrays_rng(), m)
data = Metal.rand(10, 10, 10)
Lux.apply(m, data, ps, st)
```
We need a dispatch like the one at lines 9 to 11 in f11c407.
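The referenced lines are not reproduced here, but a dispatch of the general shape being described might look like the sketch below. All names (`safe_slice`, `SomeBackendArray`) are hypothetical illustrations, not Lux's actual internals:

```julia
# Hypothetical sketch, NOT Lux's actual code: a generic fallback that
# materializes each slice so no wrapper type ever reaches the backend.
safe_slice(x::AbstractArray{T,3}, i::Integer) where {T} = copy(view(x, :, i, :))

# A backend with full wrapper support could specialize to keep the cheap view:
# safe_slice(x::SomeBackendArray{T,3}, i::Integer) where {T} = view(x, :, i, :)

data = rand(Float32, 10, 10, 10)
s = safe_slice(data, 1)   # plain Matrix{Float32}, safe for any backend
```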
The patch will be available from the next release |
This works:
but this doesn't:
Stack trace: