Merge pull request mala-project#466 from dytnvgl/dytnvgl-ddp
MALA-DDP
RandomDefaultUser authored May 12, 2024
2 parents 5cfd0c8 + b58c096 commit 7abd9d7
Showing 14 changed files with 406 additions and 266 deletions.
65 changes: 65 additions & 0 deletions docs/source/advanced_usage/trainingmodel.rst
@@ -220,3 +220,68 @@ via
The full path for ``path_to_visualization`` can be accessed via
``trainer.full_visualization_path``.


Training in parallel
********************

If large models or large data sets are employed, training may be slow even
if a GPU is used. In this case, multiple GPUs can be employed with MALA
using the ``DistributedDataParallel`` (DDP) formalism of the ``torch`` library.
To use DDP, make sure you have `NCCL <https://developer.nvidia.com/nccl>`_
installed on your system.
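
Whether the installed ``torch`` build can actually use CUDA and NCCL can be
verified with a short snippet like the following (a quick sanity check, not
part of MALA itself):

.. code-block:: python

    import torch
    import torch.distributed as dist

    # Both should report True for multi-GPU training with the NCCL backend.
    print("CUDA available:", torch.cuda.is_available())
    print("NCCL available:", dist.is_nccl_available())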

To activate and use DDP in MALA, almost no modification of your training script
is necessary. Simply activate DDP in your ``Parameters`` object. Make sure to
also enable GPU, since parallel training is currently only supported on GPUs.

.. code-block:: python

    parameters = mala.Parameters()
    parameters.use_gpu = True
    parameters.use_ddp = True

MALA is now set up for parallel training. DDP works across multiple compute
nodes on HPC infrastructure as well as on a single machine hosting multiple
GPUs. While the training script itself requires essentially no changes, the
way the script is launched may have to be adapted to ensure that DDP has all
the information it needs for inter-/intra-node communication. This setup
*may* differ across machines/clusters. During testing, the following setup
was confirmed to work on an HPC cluster using the ``slurm`` scheduler.

.. code-block:: bash

    #SBATCH --nodes=NUMBER_OF_NODES
    #SBATCH --ntasks-per-node=NUMBER_OF_TASKS_PER_NODE
    #SBATCH --gres=gpu:NUMBER_OF_TASKS_PER_NODE
    # Add more arguments as needed
    ...

    # Load more modules as needed
    ...

    # This port can be arbitrarily chosen.
    # Given here is the torchrun default
    export MASTER_PORT=29500

    # Find out the host node.
    echo "NODELIST="${SLURM_NODELIST}
    master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
    export MASTER_ADDR=$master_addr
    echo "MASTER_ADDR="$MASTER_ADDR

    # Run using srun.
    srun -u bash -c '
        # Export additional per process variables
        export RANK=$SLURM_PROCID
        export LOCAL_RANK=$SLURM_LOCALID
        export WORLD_SIZE=$SLURM_NTASKS

        python3 -u training.py
    '

An overview of environment variables to be set can be found `in the official documentation <https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization>`_.
A general tutorial on DDP itself can be found `here <https://pytorch.org/tutorials/beginner/ddp_series_theory.html>`_.
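
To make the mechanics of this setup more concrete, the sketch below shows
roughly what ``torch`` does with the exported variables. When ``use_ddp`` is
enabled, MALA performs the equivalent initialization internally, so this is
for illustration only:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # With MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE exported as in the
    # batch script above, the default "env://" initialization reads
    # everything from the environment.
    dist.init_process_group(backend="nccl")

    # Each process binds to the GPU matching its node-local rank.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    print(f"Rank {dist.get_rank()} of {dist.get_world_size()} ready.")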


4 changes: 3 additions & 1 deletion install/mala_gpu_base_environment.yml
@@ -1,4 +1,6 @@
 name: mala-gpu
 channels:
-  - defaults
+  - conda-forge
+  - defaults
 dependencies:
+  - python=3.10
4 changes: 0 additions & 4 deletions mala/common/check_modules.py
@@ -11,10 +11,6 @@ def check_modules():
"available": False,
"description": "Enables inference parallelization.",
},
"horovod": {
"available": False,
"description": "Enables training parallelization.",
},
"lammps": {
"available": False,
"description": "Enables descriptor calculation for data preprocessing "
48 changes: 22 additions & 26 deletions mala/common/parallelizer.py
@@ -2,15 +2,13 @@

 from collections import defaultdict
 import platform
+import os
 import warnings

-try:
-    import horovod.torch as hvd
-except ModuleNotFoundError:
-    pass
 import torch
+import torch.distributed as dist

-use_horovod = False
+use_ddp = False
 use_mpi = False
 comm = None
 local_mpi_rank = None
@@ -33,45 +31,43 @@ def set_current_verbosity(new_value):
     current_verbosity = new_value


-def set_horovod_status(new_value):
+def set_ddp_status(new_value):
     """
-    Set the horovod status.
+    Set the ddp status.

-    By setting the horovod status via this function it can be ensured that
+    By setting the ddp status via this function it can be ensured that
     printing works in parallel. The Parameters class does that for the user.

     Parameters
     ----------
     new_value : bool
-        Value the horovod status has.
+        Value the ddp status has.
     """
     if use_mpi is True and new_value is True:
         raise Exception(
-            "Cannot use horovod and inference-level MPI at "
-            "the same time yet."
+            "Cannot use ddp and inference-level MPI at " "the same time yet."
         )
-    global use_horovod
-    use_horovod = new_value
+    global use_ddp
+    use_ddp = new_value


 def set_mpi_status(new_value):
     """
     Set the MPI status.

-    By setting the horovod status via this function it can be ensured that
+    By setting the MPI status via this function it can be ensured that
     printing works in parallel. The Parameters class does that for the user.

     Parameters
     ----------
     new_value : bool
-        Value the horovod status has.
+        Value the MPI status has.
     """
-    if use_horovod is True and new_value is True:
+    if use_ddp is True and new_value is True:
         raise Exception(
-            "Cannot use horovod and inference-level MPI at "
-            "the same time yet."
+            "Cannot use ddp and inference-level MPI at " "the same time yet."
         )
     global use_mpi
     use_mpi = new_value
@@ -119,8 +115,8 @@ def get_rank():
         The rank of the current thread.
     """
-    if use_horovod:
-        return hvd.rank()
+    if use_ddp:
+        return dist.get_rank()
     if use_mpi:
         return comm.Get_rank()

     return 0
@@ -159,8 +155,8 @@ def get_local_rank():
     FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
     DEALINGS IN THE SOFTWARE.
     """
-    if use_horovod:
-        return hvd.local_rank()
+    if use_ddp:
+        return int(os.environ.get("LOCAL_RANK"))
     if use_mpi:
         global local_mpi_rank
         if local_mpi_rank is None:
@@ -187,8 +183,8 @@ def get_size():
     size : int
         The number of ranks.
     """
-    if use_horovod:
-        return hvd.size()
+    if use_ddp:
+        return dist.get_world_size()
     if use_mpi:
         return comm.Get_size()

@@ -209,8 +205,8 @@ def get_comm():

 def barrier():
     """General interface for a barrier."""
-    if use_horovod:
-        hvd.allreduce(torch.tensor(0), name="barrier")
+    if use_ddp:
+        dist.barrier()
     if use_mpi:
         comm.Barrier()
     return
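
For orientation, here is a minimal sketch of how the refactored helpers might
be used downstream. The import path follows the file shown above; the usage
itself is illustrative and not taken from this commit.

.. code-block:: python

    from mala.common.parallelizer import barrier, get_rank, get_size

    # Print only on rank 0 to avoid duplicated output across processes.
    if get_rank() == 0:
        print(f"Training on {get_size()} ranks.")

    # Synchronize all ranks, e.g., before reading files written by rank 0.
    barrier()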