This project leverages a Long Short-Term Memory (LSTM) network enhanced with an attention mechanism to analyze biological sequences for disease prediction. It is implemented using PyTorch, and includes data visualization through a Dash web application.
- Python 3.6+
- PyTorch
- NumPy
- BioPython
- scikit-learn
- Dash and Plotly for visualization
- Dash Bootstrap Components
You can install the necessary libraries using pip:
pip install torch numpy biopython scikit-learn dash dash-bootstrap-components plotly
- AttentionLayer: A custom neural network layer that applies a softmax function to LSTM outputs to focus on important features.
- LSTMWithAttention: The main model that uses an LSTM layer and the custom AttentionLayer to predict disease risk from sequence data.
- load_sequences_and_labels: Loads sequences from a FASTA file and extracts labels and metadata.
- encode_sequences: Converts nucleotide sequences into numerical format for model input.
- train_model: Handles the training process, including validation loss calculation and optimizer adjustments.
- setup_dash_app: Sets up a Dash application to visualize sequence data and display predictions using an interactive slider and Plotly graphs.
- Load and encode sequence data from a FASTA file.
- Initialize the LSTM with Attention model and configure its parameters.
- Train the model using the encoded data.
- Launch the Dash application to visualize and interact with the sequence data and predictions.
To run the application, ensure you have the example.fasta
file in your working directory, then execute:
if __name__ == "__main__":
main()
This will start a server for the Dash application, allowing you to interact with the sequence data and view predictions in real-time.
Logging and error handling are essential for monitoring the application's health and performance. Ensure you have proper logging configurations to capture and diagnose issues.