RNN design for efficient CUDNN usage #1365

Closed
jeremiedb opened this issue Oct 18, 2020 · 10 comments

@jeremiedb
Contributor

With the current Flux RNN design, where each data batch is a Vector of length seq_length whose elements are matrices of size [features, batch_size], I'm afraid we are missing some performance opportunities.

In particular, the CUDNN RNN function is designed to directly handle an array of size [features, batch_size, seq_length] and to apply the entire RNN chain (vanilla RNN, GRU, LSTM) over the full sequence. Moreover, the CUDNN operator supports stacking multiple layers, both unidirectional and bidirectional modes, and a vector specifying the length of each sequence.

The current Flux RNN design goes over the sequence one timestep at a time, following the `m.(x)` (or `map(m, x)`) pattern recommended in the docs. This effectively seems to translate into one call to the underlying CUDNN per timestep: https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L111, where a seq_length of 1 is hard-coded.

Also, a single layer is assumed:

`r = CUDNN.RNNDesc{T}(mode, i, h)`

where the optional `layers` argument is left at its default value of 1 (https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L42).

Other DL frameworks take this route. In MXNet, for example, a step-by-step approach is supported, but a high-performance fused RNN operator is also available: https://mxnet.apache.org/versions/1.6/api/python/docs/api/ndarray/ndarray.html#mxnet.ndarray.RNN.
Such an operator works on data shaped `features X batch_size X seq_length`.

It looks like the CUDA.jl wrapper (https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/rnn.jl) is almost there in supporting arbitrary sequence lengths, as well as the bidirectional, seq_length and clipping options.

To take advantage of such a backend, it seems we would need to move away from the Vector of timesteps to a 3D array representation. I think that makes sense, as it's fairly intuitive to consider each batch as a single array. With such a 3D array, the traditional non-fused approach could still be expressed, for example through `mapslices`, which wouldn't be a big departure from the current approach:

```julia
using Flux

batch = rand(Float32, 2, 3, 4)  # features X batch_size X seq_length
rnn = Chain(RNN(2, 5))
m(x) = mapslices(rnn, x; dims=(1, 2))  # apply rnn to each [features, batch_size] timestep slice
```
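Switching between the current layout and the proposed 3D one is a cheap reshuffle in plain Julia. A minimal Base-only sketch with made-up sizes (2 features, batch of 3, 4 timesteps):

```julia
# Current Flux layout: a Vector of seq_length matrices, each [features, batch_size]
xs = [rand(Float32, 2, 3) for _ in 1:4]

# Proposed layout: one [features, batch_size, seq_length] array
# (on Julia 1.9+, `stack(xs)` gives the same result)
x3 = cat(xs...; dims=3)

# And back again, one copied matrix per timestep
xs_back = [x3[:, :, t] for t in axes(x3, 3)]

# Or as lazy views, without copying the data
xs_views = collect(eachslice(x3; dims=3))
```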

And a high performance RNN could be accessible with something along those lines:

```julia
batch = rand(Float32, 2, 3, 4)  # features X batch_size X seq_length
m = Chain(FusedRNN(2, 5, layers=3, bidirectional=true, ...))  # proposed API, does not exist yet
```
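For reference, here is a plain-Julia sketch of what options like `layers=3, bidirectional=true` would mean semantically for a vanilla RNN on a [features, batch_size, seq_length] array. This is not Flux or CUDNN code: `rnn_params`, `run_dir`, `birnn_layers` and `stacked_birnn` are hypothetical names, and concatenating the forward and backward outputs along the feature dimension (so every layer after the first sees 2*hidden inputs) is an assumed convention.

```julia
using Random

# Parameters for one direction of one layer (hypothetical layout)
rnn_params(rng, nin, hidden) = (Wx = randn(rng, Float32, hidden, nin),
                                Wh = randn(rng, Float32, hidden, hidden),
                                b  = zeros(Float32, hidden))

# Run one direction over the whole sequence; reverse=true walks t = T, T-1, ..., 1
function run_dir(p, x; reverse::Bool=false)
    hidden, batch, T = length(p.b), size(x, 2), size(x, 3)
    y = zeros(Float32, hidden, batch, T)
    h = zeros(Float32, hidden, batch)
    for t in (reverse ? (T:-1:1) : (1:T))
        h = tanh.(p.Wx * view(x, :, :, t) .+ p.Wh * h .+ p.b)
        y[:, :, t] = h
    end
    return y
end

# One (forward, backward) parameter pair per layer; layers after the first
# take the 2*hidden concatenated outputs as input
birnn_layers(rng, nin, hidden, layers) =
    [(rnn_params(rng, l == 1 ? nin : 2hidden, hidden),
      rnn_params(rng, l == 1 ? nin : 2hidden, hidden)) for l in 1:layers]

# Stack the layers: each one feeds its concatenated outputs to the next
function stacked_birnn(ps, x)
    for (fwd, bwd) in ps
        x = cat(run_dir(fwd, x), run_dir(bwd, x; reverse=true); dims=1)
    end
    return x
end

ps = birnn_layers(MersenneTwister(0), 2, 5, 3)  # 2 features, hidden size 5, 3 layers
y = stacked_birnn(ps, rand(Float32, 2, 3, 4))  # size (10, 3, 4): 2*hidden, batch_size, seq_length
```

A fused kernel can run this whole double loop (layers and timesteps) inside one call, which is the opportunity the per-timestep design gives up.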

Is such a direction for RNNs sound? I may well have overlooked considerations particular to Flux, but I think it would be greatly beneficial to bring both:

  • A more regular representation of a batch through a single 3D array
  • Access to the common CUDNN operators

@AzamatB @DhairyaLGandhi

@AzamatB
Contributor

AzamatB commented Oct 19, 2020

Great write-up @jeremiedb! Would love to see this added to Flux.jl.

@pshashk
Contributor

pshashk commented Oct 19, 2020

I think, for performance and ease-of-use reasons, it is desirable to have a container type for handling batches of variable-length sequences, similar to PyTorch's PackedSequence or TensorFlow's RaggedTensor. Otherwise, the user needs to store a vector of lengths and write logic to keep the model from back-propagating the loss through the padding. The RNN kernel itself also can't skip unnecessary computation if the data doesn't carry the padding information.

@jeremiedb
Contributor Author

jeremiedb commented Oct 19, 2020

The CUDNN RNN operator does support varying sequence lengths: you provide a vector of Ints of length batch_size that indicates the valid length of each sequence.
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetRNNDataDescriptor
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNPaddingMode_t

In any case, I would expect that in the end the user will always need to provide the seq_length information in some way.
Is a specific container required to carry that information? During training, if the iterator returns `x, seq_length, y` rather than the typical `x, y`, it seems the need would be met. A CUDNN-backed layer could then be called with `(m::CudnnRNN)(x, seq_length) = ...`, or just with `(m::CudnnRNN)(x) = ...` when `use_sequence_length=false`. Alternatively, with manual RNN cells, the user would likely add a masking layer that uses the seq_length information.
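A minimal sketch of the masking route, assuming the iterator yields `(x, seq_lengths, y)` with the targets and the model output shaped [features, batch_size, seq_length] and sequences left-aligned (valid steps first, padding after). `length_mask` and `masked_mse` are hypothetical helpers, not Flux functions.

```julia
# Hypothetical helper: a [1, batch_size, seq_length] mask that is 1 on valid
# timesteps and 0 on padding, built from the per-sequence lengths
function length_mask(seq_lengths::AbstractVector{<:Integer}, T::Integer)
    mask = zeros(Float32, 1, length(seq_lengths), T)
    for (b, L) in enumerate(seq_lengths)
        mask[1, b, 1:L] .= 1f0
    end
    return mask
end

# Mean squared error over valid timesteps only: padding contributes neither
# to the sum nor to the count used for averaging
function masked_mse(yhat, y, seq_lengths)
    mask = length_mask(seq_lengths, size(y, 3))  # broadcasts over the feature dim
    return sum(mask .* (yhat .- y) .^ 2) / (size(y, 1) * sum(seq_lengths))
end

y = zeros(Float32, 2, 3, 4)       # targets, [features, batch_size, seq_length]
yhat = ones(Float32, 2, 3, 4)     # predictions
seq_lengths = [4, 2, 1]
masked_mse(yhat, y, seq_lengths)  # returns 1.0f0: every valid entry is off by exactly 1
```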

@pshashk Other than bringing more high-level API usability, do you see benefits in a specific RNN container vs adding a seq_length element to the usual data iterator?

@ToucheSir
Member

@DrChainsaw
Contributor

As much as I like the explicitness of the current approach, I guess that this would also alleviate some ONNX export/import headaches:
DrChainsaw/ONNXNaiveNASflux.jl#12

Would a natural consequence be that one also wants Dense to broadcast over more than the batch dimension, so that Dense layers can be stacked after FusedRNNs as well? I think I saw a PR about that long ago...
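One way to get that behaviour without a new layer type is to fold the batch and sequence dimensions into one, apply the affine map once, and unfold. A Base-only sketch; `dense3d` is a hypothetical function, with `W` and `b` standing in for a Dense layer's weight and bias:

```julia
# Apply W*x .+ b to every timestep of a [features, batch_size, seq_length] array
# with a single matrix multiply, by folding the trailing two dimensions together
function dense3d(W::AbstractMatrix, b::AbstractVector, x::AbstractArray{<:Any,3})
    f, batch, T = size(x)
    y = W * reshape(x, f, batch * T) .+ b    # [out, batch_size*seq_length]
    return reshape(y, size(W, 1), batch, T)  # back to [out, batch_size, seq_length]
end

W, b = rand(Float32, 5, 2), rand(Float32, 5)
x = rand(Float32, 2, 3, 4)
yd = dense3d(W, b, x)  # size (5, 3, 4)
```

Because Julia arrays are column-major, the reshape keeps each timestep's [features, batch_size] slice contiguous, so unfolding restores the original timestep order.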

Can backwards compatibility be handled by dispatching on the number of array dimensions?
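A sketch of how dispatch on the input type could keep the current Vector-of-matrices call working next to a 3D method. `ToyRNN` and `rnn_step` are hypothetical stand-ins, not Flux's types, and the hidden state is reset to zero on every call for simplicity:

```julia
using Random

# Hypothetical stand-in for an RNN layer (not Flux's type)
struct ToyRNN
    Wx::Matrix{Float32}
    Wh::Matrix{Float32}
end

# One timestep: x is [features, batch_size], h is [hidden, batch_size]
rnn_step(m::ToyRNN, h, x) = tanh.(m.Wx * x .+ m.Wh * h)

# Current layout: Vector of seq_length matrices -> Vector of outputs
function (m::ToyRNN)(xs::AbstractVector{<:AbstractMatrix})
    h = zeros(Float32, size(m.Wh, 1), size(first(xs), 2))
    ys = Vector{Matrix{Float32}}(undef, length(xs))
    for (t, x) in enumerate(xs)
        h = rnn_step(m, h, x)
        ys[t] = h
    end
    return ys
end

# Proposed layout: [features, batch_size, seq_length] -> [hidden, batch_size, seq_length]
function (m::ToyRNN)(x::AbstractArray{<:Real,3})
    h = zeros(Float32, size(m.Wh, 1), size(x, 2))
    y = zeros(Float32, size(m.Wh, 1), size(x, 2), size(x, 3))
    for t in axes(x, 3)
        h = rnn_step(m, h, view(x, :, :, t))
        y[:, :, t] = h
    end
    return y
end

m = ToyRNN(randn(MersenneTwister(0), Float32, 5, 2), randn(MersenneTwister(1), Float32, 5, 5))
x = rand(Float32, 2, 3, 4)
y3 = m(x)                          # 5×3×4 array
yv = m([x[:, :, t] for t in 1:4])  # Vector of four 5×3 matrices
```

The two methods never overlap (a Vector is one-dimensional), so existing code passing a Vector keeps its behaviour while 3D input picks the new path.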

Perhaps not so nice due to type inference and surprising interop with other layers, but I guess a sentinel value like missing or Zero could also be used to derive the seq_length vector without any need for extra input.
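A sketch of the sentinel idea, assuming padded timesteps are filled with `missing` and every sequence is left-aligned (valid steps first, padding after). `lengths_from_missing` is a hypothetical helper; note that the element type becomes `Union{Missing,Float32}`, which is where the type-inference cost comes from.

```julia
# Length of each sequence = number of leading timesteps with no `missing` entry
function lengths_from_missing(x::AbstractArray{<:Any,3})
    batch, T = size(x, 2), size(x, 3)
    return [something(findfirst(t -> any(ismissing, view(x, :, b, t)), 1:T), T + 1) - 1
            for b in 1:batch]
end

x = Array{Union{Missing,Float32}}(undef, 2, 3, 4)  # features × batch_size × seq_length
x .= 1f0
x[:, 2, 3:4] .= missing  # sequence 2 has length 2
x[:, 3, 2:4] .= missing  # sequence 3 has length 1
lengths_from_missing(x)  # [4, 2, 1]
```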

@pshashk
Contributor

pshashk commented Oct 20, 2020

> @pshashk Other than from bringing more high level API usability, are you seeing benefits from a specific RNN container vs adding a seq_length element to the usual data iterator?

One advantage is the option to implement padding-aware methods for other layers. Convolution, normalization, and pooling applied after an RNN may introduce a dependence between the weights and the amount of padding, which is rarely desirable. I'm not sure there is enough demand for that to live in Flux rather than in some complementary package.
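As an illustration of a padding-aware method: mean pooling over time can divide by each sequence's own length instead of the padded length, so the pooled features do not change when more padding is added. `masked_meanpool` is a hypothetical helper assuming left-aligned sequences with lengths of at least 1:

```julia
# Mean over the valid timesteps of each sequence in a [hidden, batch_size, seq_length] array
function masked_meanpool(y::AbstractArray{<:Real,3}, seq_lengths::AbstractVector{<:Integer})
    hidden, batch, _ = size(y)
    pooled = zeros(Float32, hidden, batch)
    for (b, L) in enumerate(seq_lengths)  # assumes 1 <= L <= seq_length
        pooled[:, b] .= vec(sum(view(y, :, b, 1:L); dims=2)) ./ L
    end
    return pooled
end

y = ones(Float32, 2, 3, 4)
y[:, 2, 3:4] .= 100f0  # arbitrary values on sequence 2's padded timesteps
seq_lengths = [4, 2, 1]
p4 = masked_meanpool(y, seq_lengths)
p6 = masked_meanpool(cat(y, fill(7f0, 2, 3, 2); dims=3), seq_lengths)  # two more padding steps
```

Both `p4` and `p6` are all ones; a plain mean over the third dimension would instead change with the padding values and the padded length.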

@DhairyaLGandhi
Member

Making a clean transition to a 3D array is fine, but have we checked the correctness of the gradients? Further, it was done this way to match the cudnn style; can we verify the performance considerations? We should still be able to use the single-timestep API to call cudnn if that works better, but we will have to think about whether this is sound.

@jeremiedb
Contributor Author

@DhairyaLGandhi Maybe there's some confusion, but I'm fairly confident that the CUDNN approach is 3D-based, as discussed here and as evidenced by the CUDA.jl RNN API: https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L120, or directly by the CUDNN API docs.

My original concern came from what appeared to be unreliable support for RNNs on GPUs, and I thought that starting from a known reference such as the 3D CUDNN API could help lay a more robust and performant foundation. It seems, though, that Knet did manage to get performant RNNs based on the iterative sequence approach. Also, in JuliaGPU/CUDA.jl#343, Knet's author discussed more extensive support for CUDNN functions.

I guess the best scenario would be to have the option to get vanilla RNNs through the 3D layout and CUDNN dispatch, while keeping the flexibility of the 2D approach for ad hoc uses.

But I think the first priority should be to have functional RNNs on GPUs. That's what I tried to explore in the PR above, by keeping the original 2D design and dropping cudnn.

bors bot added a commit that referenced this issue Nov 7, 2020
1367: RNN update to drop CUDNN, fix LSTM bug and output type stability r=CarloLucibello a=jeremiedb

PR related to #1114 #1360 #1365 

Some experiment for RNN handling. 

The hidden state field of each cell struct was dropped, as it wasn't needed (AFAIK it was only needed for size inference for CUDNN, but the bias size could be used as a substitute for the cells' `h` there as well).

This drops the dependence on CUDNN entirely, so it's pure Flux/CUDA.jl. The file `src/cuda/curnn.jl` is no longer used. No modifications were made to the cell computations. Initial tests seem to show decent performance, but it has yet to be benchmarked.

Pending issue: despite completely dropping the CUDNN dependency, there's still an instability issue that seems to appear when running on GPU. This is illustrated in the test at lines 1-50 of file `test\rnn-test-jdb.jl`. When that test runs on CPU, it goes through the 100 iterations fine. However, the same test on GPU produces NaNs after a couple dozen iterations.
My only hypothesis so far: when iterating over the sequence through `m.(x)` or `map(rnn, x)`, is the order of execution safe? I.e., is it possible that there isn't a `sync()` on the CUDA side between those sequence steps, which may mess up the state?

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: jeremiedb
Co-authored-by: jeremie.db
@CarloLucibello CarloLucibello added this to the v0.15 milestone Oct 14, 2024
@CarloLucibello
Member

Given that we have severe correctness issues with explicit differentiation (#2185) because Recur internally mutates its state, I think it would be good to resume this redesign.
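For illustration, a cell that holds only parameters and receives and returns its hidden state explicitly, instead of storing it in a mutable wrapper like Recur. This is a plain-Julia sketch with hypothetical names (`ToyRNNCell`, `run_sequence`), not the design adopted in Flux; it only shows that two calls with the same initial state cannot interfere with each other:

```julia
using Random

# Hypothetical cell: parameters only, no stored state
struct ToyRNNCell
    Wx::Matrix{Float32}
    Wh::Matrix{Float32}
    b::Vector{Float32}
end

# (state, input) -> (new_state, output); for a vanilla RNN both are the new hidden state
function (c::ToyRNNCell)(h, x)
    hnew = tanh.(c.Wx * x .+ c.Wh * h .+ c.b)
    return hnew, hnew
end

# Run over a [features, batch_size, seq_length] array, threading the state through
# the loop and returning the final state alongside the outputs
function run_sequence(c::ToyRNNCell, h0, x::AbstractArray{<:Real,3})
    h = h0
    ys = Matrix{Float32}[]
    for t in axes(x, 3)
        h, y = c(h, x[:, :, t])
        push!(ys, y)
    end
    return cat(ys...; dims=3), h
end

c = ToyRNNCell(randn(MersenneTwister(0), Float32, 5, 2),
               randn(MersenneTwister(1), Float32, 5, 5),
               zeros(Float32, 5))
x = rand(Float32, 2, 3, 4)
h0 = zeros(Float32, 5, 3)
y1, h1 = run_sequence(c, h0, x)
y2, h2 = run_sequence(c, h0, x)  # same inputs, same results: nothing carried between calls
```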

@CarloLucibello CarloLucibello mentioned this issue Oct 14, 2024
@CarloLucibello
Member

With #2500, RNN, GRU and LSTM take inputs of size features x seq_len x batch_size.
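Data prepared in the features X batch_size X seq_length layout proposed at the top of this issue, or as the older Vector of per-timestep matrices, can be brought to the features x seq_len x batch_size layout with a single `permutedims`. A Base-only sketch; the exact Flux call signatures after #2500 are not shown here:

```julia
# Proposed-at-the-top layout: features × batch_size × seq_length
x_fbs = rand(Float32, 2, 3, 4)

# Swap the last two dimensions: features × seq_len × batch_size
x_fsb = permutedims(x_fbs, (1, 3, 2))

# From the older Vector-of-matrices layout ([features, batch_size] per timestep)
xs = [rand(Float32, 2, 3) for _ in 1:4]
x_from_vec = permutedims(cat(xs...; dims=3), (1, 3, 2))  # features × seq_len × batch_size
```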

7 participants