This repository contains the official TensorFlow implementation of the following paper:
Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences
by Andis Draguns, Emīls Ozoliņš, Agris Šostaks, Matīss Apinis, Kārlis Freivalds
Abstract: Attention is a commonly used mechanism in sequence processing, but it is of O(n²) complexity which prevents its application to long sequences. The recently introduced neural Shuffle-Exchange network offers a computation-efficient alternative, enabling the modelling of long-range dependencies in O(n log n) time. The model, however, is quite complex, involving a sophisticated gating mechanism derived from the Gated Recurrent Unit.
In this paper, we present a simple and lightweight variant of the Shuffle-Exchange network, which is based on a residual network employing GELU and Layer Normalization. The proposed architecture not only scales to longer sequences but also converges faster and provides better accuracy. It surpasses the Shuffle-Exchange network on the LAMBADA language modelling task and achieves state-of-the-art performance on the MusicNet dataset for music transcription while being efficient in the number of parameters.
We show how to combine the improved Shuffle-Exchange network with convolutional layers, establishing it as a useful building block in long sequence processing applications.
Residual Shuffle-Exchange networks are a simpler and faster replacement for the recently proposed Neural Shuffle-Exchange network architecture. They have O(n log n) complexity and enable processing of sequences up to 2 million symbols long, where standard methods such as attention mechanisms fail. The Residual Shuffle-Exchange network can serve as a useful building block for long sequence processing applications.
Click the gif to see the full video on YouTube:
Our paper describes Residual Shuffle-Exchange networks in detail and provides full results on long binary addition, long binary multiplication, sorting tasks, the LAMBADA question answering task and multi-instrument musical note recognition using the MusicNet dataset.
Here are the accuracy results on the MusicNet transcription task of identifying the musical notes performed from audio waveforms (freely-licensed classical music recordings):
| Model | Learnable parameters (M) | Average precision score (%) |
|---|---|---|
| cgRNN | 2.36 | 53.0 |
| Deep Real Network | 10.0 | 69.8 |
| Deep Complex Network | 8.8 | 72.9 |
| Complex Transformer | 11.61 | 74.22 |
| Translation-invariant net | unknown | 77.3 |
| Residual Shuffle-Exchange network | 3.06 | 78.02 |
Note: Our model achieves state-of-the-art performance while being parameter-efficient and operating directly on the audio waveform, whereas the previous state-of-the-art models used specialised architectures with complex-number representations of the Fourier-transformed waveform.
Here are the accuracy results on the LAMBADA question answering task of predicting a target word in its broader context (on average 4.6 sentences picked from novels):
| Model | Learnable parameters (M) | Test accuracy (%) |
|---|---|---|
| Random word from passage | - | 1.6 |
| Gated-Attention Reader | unknown | 49.0 |
| Neural Shuffle-Exchange network | 33 | 52.28 |
| Residual Shuffle-Exchange network | 11 | 54.34 |
| Universal Transformer | 152 | 56.0 |
| Human performance | - | 86.0 |
| GPT-3 | 175000 | 86.4 |
Note: Our model runs faster than the Shuffle-Exchange network model and, using the same amount of GPU memory, can be evaluated on sequences 4 times longer than the Shuffle-Exchange network and 128 times longer than the Universal Transformer model.
Residual Shuffle-Exchange networks are a lightweight variant of Shuffle-Exchange networks: continuous, differentiable neural networks with a regular layered structure consisting of alternating Switch and Shuffle layers.
The Switch Layer divides the input into adjacent pairs of values and applies a Residual Switch Unit, a learnable 2-to-2 function employing GELU and Layer Normalization, to each pair, producing two outputs per pair.
Here is an illustration of a Residual Switch Unit, which replaces the Switch Unit from Shuffle-Exchange networks:
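In code, the unit can be sketched roughly as follows. This is a minimal NumPy sketch for intuition only, not the repository's TensorFlow implementation; the weight matrices `Z` and `W` and the fixed `scale` constant are illustrative placeholders (the paper uses a learned residual weighting):

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    # normalize a feature vector to zero mean and unit variance
    return (x - x.mean()) / (x.std() + eps)

def residual_switch_unit(pair, Z, W, scale=0.25):
    # linear map, LayerNorm and GELU, second linear map,
    # then a scaled residual connection back to the inputs
    candidate = W @ gelu(layer_norm(Z @ pair))
    out = pair + scale * candidate
    m = len(out) // 2
    return out[:m], out[m:]

# Example: feature size m = 4, random placeholder weights
m = 4
rng = np.random.default_rng(0)
Z = rng.normal(size=(2 * m, 2 * m))
W = rng.normal(size=(2 * m, 2 * m))
o1, o2 = residual_switch_unit(rng.normal(size=2 * m), Z, W)
```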
The Shuffle Layer follows, where inputs are permuted according to a perfect-shuffle permutation (i.e., how a deck of cards is shuffled by splitting it into halves and then interleaving them). The permutation is a cyclic bit shift of the element indices, rotating left in the first part of the network and rotating right (the inverse) in the second part.
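As a concrete illustration, here is a minimal Python sketch (not code from this repository) of the perfect-shuffle permutation expressed as a cyclic left rotation of the element-index bits:

```python
def rotate_left(i, n_bits):
    # cyclically rotate the n_bits-bit index i left by one bit
    mask = (1 << n_bits) - 1
    return ((i << 1) & mask) | (i >> (n_bits - 1))

def perfect_shuffle(seq):
    # split the sequence into halves and interleave them:
    # the element at position p moves to position rotate_left(p)
    n_bits = (len(seq) - 1).bit_length()
    out = [None] * len(seq)
    for p, x in enumerate(seq):
        out[rotate_left(p, n_bits)] = x
    return out

print(perfect_shuffle(list(range(8))))  # [0, 4, 1, 5, 2, 6, 3, 7]
```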
The Residual Shuffle-Exchange network is organized in blocks by alternating these two kinds of layers in the pattern of the Beneš network. Such a network can represent a wide class of functions including any permutation of the input values.
Here is an illustration of a whole Residual Shuffle-Exchange network model consisting of two blocks with 8 inputs:
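For intuition, the block structure can be sketched as follows, reusing the `perfect_shuffle` helper from the sketch above; `unit` stands for any learnable 2-to-2 function such as the Residual Switch Unit, and the exact layer counts and weight sharing in the actual implementation may differ:

```python
def inverse_shuffle(seq):
    # undo a perfect shuffle: de-interleave back into two halves
    return list(seq[0::2]) + list(seq[1::2])

def switch_layer(seq, unit):
    # apply a 2-to-2 unit to each adjacent pair of elements
    out = []
    for a, b in zip(seq[0::2], seq[1::2]):
        out.extend(unit(a, b))
    return out

def benes_block(seq, unit):
    # a Beneš block on 2**k inputs: k - 1 switch layers interleaved with
    # perfect shuffles, then k - 1 more with inverse shuffles, and a final
    # switch layer (2k - 1 switch layers in total)
    k = (len(seq) - 1).bit_length()
    for _ in range(k - 1):
        seq = perfect_shuffle(switch_layer(seq, unit))
    for _ in range(k - 1):
        seq = inverse_shuffle(switch_layer(seq, unit))
    return switch_layer(seq, unit)

# usage with the Residual Switch Unit sketched earlier:
# unit = lambda a, b: residual_switch_unit(np.concatenate([a, b]), Z, W)
# seq = benes_block(seq, unit)
```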
Running the experiments requires the following dependencies to be installed:
- Python 3.6 or higher
- TensorFlow 1.14.0
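For example, assuming a Python 3.6 or 3.7 environment (TensorFlow 1.14.0 does not support later Python versions), the TensorFlow dependency can be installed with pip:

```
pip3 install tensorflow==1.14.0
```

For GPU training, `tensorflow-gpu==1.14.0` can be installed instead.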
To start training the Residual Shuffle-Exchange network, run the terminal command:
```
python3 trainer.py
```
By default it will train on the music transcription task. To select the sequence processing task for which to train the Residual Shuffle-Exchange network, edit the `config.py` file, which contains various hyperparameter and setting options.

For the MusicNet music transcription task, make sure that the corresponding settings in `config.py` are uncommented:
"""Recommended settings for MusicNet"""
# task = "musicnet"
# n_Benes_blocks = 2 # depth of the model
...
To train the model on MusicNet, the dataset has to be downloaded and parsed, which can be done by running:
```
python3 musicnet_data/get_musicnet.py
python3 musicnet_data/parse_file.py
```
This might take a while. If you run out of RAM (it can take more than 40 GB), you can download `musicnet.npz` from Kaggle and place it in the `musicnet_data` directory.
If you have enough RAM to load the entire dataset (it can take more than 128 GB), set `musicnet_subset` to `False` for faster training. Increasing `musicnet_window_size` requires more RAM and slows training but produces greater accuracy.
To use a pretrained model for music transcription, place the contents of `trained_model_m8192F1` in the `out_dir` directory specified in the `config.py` file.
To test the trained model for the MusicNet task on the test set, run `tester.py`. To transcribe a custom wav file to MIDI, place the file in the `musicnet_data` directory and run:

```
python3 transcribe.py yourwavfile.wav
```
For the LAMBADA question answering task, uncomment the corresponding settings in `config.py`:
"""Recommended settings for lambada"""
# task = "lambada"
# n_input = lambada_vocab_size
...
To download the LAMBADA dataset, see the original publication by Paperno et al.
To download the pre-trained fastText 1M English word embedding, see the downloads section of the fastText library website and extract it to the directory listed in the `config.py` file variable `base_folder` under "Embedding configuration":
"""Embedding configuration"""
use_pre_trained_embedding = False
base_folder = "/host-dir/embeddings/"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...
To enable the pre-trained embedding, change the `config.py` file variable `use_pre_trained_embedding` to `True`.
If you are running Windows, before starting training the Residual Shuffle-Exchange network, edit the `config.py` file to change the directory-related variables to the Windows file path format in the following way:
```python
...
"""Local storage (checkpoints, etc)"""
...
out_dir = ".\\host-dir\\gpu" + gpu_instance
model_file = out_dir + "\\varWeights.ckpt"
image_path = out_dir + "\\images"
...
```
If you are doing music transcription on Windows, the directory-related variables in the MusicNet-related files need to be changed in a similar manner.
If you use Residual Shuffle-Exchange networks, please use the following BibTeX entry when citing the paper:
```bibtex
@inproceedings{draguns2021residual,
  title={Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences},
  author={Draguns, Andis and Ozoli{\c{n}}{\v{s}}, Em{\=\i}ls and {\v{S}}ostaks, Agris and Apinis, Mat{\=\i}ss and Freivalds, Karlis},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={8},
  pages={7245--7253},
  year={2021}
}
```
For help or issues using Residual Shuffle-Exchange networks, please submit a GitHub issue.
For personal communication related to Residual Shuffle-Exchange networks, please contact Kārlis Freivalds ([email protected]).