Skip to content

Commit

Permalink
Merge branch 'hsdp-ci-tests' of github.com-regular:mosaicml/composer …
Browse files Browse the repository at this point in the history
…into hsdp-ci-tests
  • Loading branch information
v-chen_data committed Jun 20, 2024
2 parents eda5ede + 4b55781 commit f8f1145
Show file tree
Hide file tree
Showing 29 changed files with 808 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.11-2.3
container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-doctest
Expand Down
4 changes: 2 additions & 2 deletions composer/algorithms/augmix/augmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def _augmix_pil_image(
aug = np.random.choice(augmentation_set)
augmented_image = aug(augmented_image, severity)
augmented_combination += chain_weights[chain_i] * np.asarray(augmented_image)
mixed = (1 - mixing_weight) * np.asarray(img_pil) + mixing_weight * augmented_combination
mixed = Image.fromarray(np.uint8(mixed))
mixed = (1 - mixing_weight) * np.asarray(img_pil, dtype=np.float32) + mixing_weight * augmented_combination
mixed = Image.fromarray(np.uint8(mixed)) # type: ignore
return mixed

f_pil = functools.partial(
Expand Down
31 changes: 26 additions & 5 deletions composer/algorithms/utils/augmentation_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import numpy as np
from PIL import Image, ImageEnhance, ImageOps
from PIL.Image import Resampling, Transform

AugmentationFn = Callable[[Image.Image, float], Image.Image]

Expand Down Expand Up @@ -155,7 +156,7 @@ def rotate(pil_img: Image.Image, level: float):
degrees = _int_parameter(_sample_level(level), 30)
if np.random.uniform() > 0.5:
degrees = -degrees
return pil_img.rotate(degrees, resample=Image.BILINEAR)
return pil_img.rotate(degrees, resample=Resampling.BILINEAR)


def solarize(pil_img: Image.Image, level: float):
Expand Down Expand Up @@ -183,7 +184,12 @@ def shear_x(pil_img: Image.Image, level: float):
level = _float_parameter(_sample_level(level), 0.3)
if np.random.uniform() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, level, 0, 0, 1, 0),
resample=Resampling.BILINEAR,
)


def shear_y(pil_img: Image.Image, level: float):
Expand All @@ -197,7 +203,12 @@ def shear_y(pil_img: Image.Image, level: float):
level = _float_parameter(_sample_level(level), 0.3)
if np.random.uniform() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, 0, level, 1, 0),
resample=Resampling.BILINEAR,
)


def translate_x(pil_img: Image.Image, level: float):
Expand All @@ -211,7 +222,12 @@ def translate_x(pil_img: Image.Image, level: float):
level = _int_parameter(_sample_level(level), pil_img.size[0] / 3)
if np.random.random() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, level, 0, 1, 0),
resample=Resampling.BILINEAR,
)


def translate_y(pil_img: Image.Image, level: float):
Expand All @@ -225,7 +241,12 @@ def translate_y(pil_img: Image.Image, level: float):
level = _int_parameter(_sample_level(level), pil_img.size[1] / 3)
if np.random.random() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, 0, 0, 1, level),
resample=Resampling.BILINEAR,
)


# The following augmentations overlap with corruptions in the ImageNet-C/CIFAR10-C test
Expand Down
93 changes: 79 additions & 14 deletions composer/callbacks/system_metrics_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os

import psutil
import torch

from composer.core import Callback, Event, State
from composer.loggers import Logger
Expand All @@ -19,13 +20,52 @@

__all__ = ['SystemMetricsMonitor']

_GPU_METRICS = [
'gpu_percentage',
'memory_percentage',
'gpu_temperature_C',
'gpu_power_usage_W',
]


class SystemMetricsMonitor(Callback):
"""Track system metrics."""
"""Logs GPU/CPU metrics.
GPU Metrics:
gpu_percentage: Occupancy rate, percent of time over sampling period during which one or more kernels was executing on the GPU.
memory_percentage: Percent of time over sampling period during which global memory was being read or written.
gpu_temperature_C: Temperature of device, in Celcius.
gpu_power_usage_W: Power usage of device, in Watts.
By default, only the maximum and minimum values for these metrics, alongside their respective ranks in the key names,
are logged on the :attr:`.Event.BATCH_START`, :attr:`.Event.EVAL_BATCH_START`, :attr:`.Event.PREDICT_BATCH_START`
events for every batch. If log_all_data is set to True, all values for these metrics across all ranks are logged on the
above events for every batch.
Example:
.. doctest::
def __init__(self, gpu_available: bool = False) -> None:
>>> from composer import Trainer
>>> from composer.callbacks import SystemMetricsMonitor
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration='1ep',
... callbacks=[SystemMetricsMonitor()],
... )
Args:
log_all_data (bool, optional): True if user wants to log data for all ranks, not just the min/max.
Defaults to False.
"""

def __init__(self, log_all_data: bool = False) -> None:
super().__init__()
self.gpu_available = gpu_available
self.gpu_available = torch.cuda.is_available()
self.log_all_data = log_all_data
if self.gpu_available:
try:
import pynvml
Expand All @@ -46,9 +86,23 @@ def run_event(self, event: Event, state: State, logger: Logger):
]:
local_node_system_metrics = self.compute_system_metrics()
all_system_metrics = dist.all_gather_object(local_node_system_metrics)
system_metrics = {
key: value for local_metrics in all_system_metrics for key, value in local_metrics.items()
}
system_metrics = {}

if self.log_all_data:
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key in _GPU_METRICS:
system_metrics[f'{key}_rank_{rank}'] = value
else:
system_metrics[key] = value

else:
system_metrics = self.compute_gpu_min_max_metrics(all_system_metrics, state)
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key not in _GPU_METRICS:
system_metrics[key] = value

logger.log_metrics(system_metrics)

def compute_system_metrics(self):
Expand All @@ -58,17 +112,14 @@ def compute_system_metrics(self):
if self.gpu_available:
import pynvml
local_rank = dist.get_local_rank()
global_rank = dist.get_global_rank()
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
system_metrics[f'device{global_rank}_memory_total'] = memory.total
system_metrics[f'device{global_rank}_memory_free'] = memory.free
system_metrics[f'device{global_rank}_memory_used'] = memory.used
device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
system_metrics[f'device{global_rank}_gpu_percentage'] = device_utilization.gpu
system_metrics[f'device{global_rank}_memory_percentage'] = device_utilization.memory
system_metrics['gpu_percentage'] = device_utilization.gpu
system_metrics['memory_percentage'] = device_utilization.memory
temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
system_metrics[f'device{global_rank}_gpu_temperature'] = temperature
system_metrics['gpu_temperature_C'] = temperature
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # convert from mW to W
system_metrics['gpu_power_usage_W'] = power

# Get metrics for the system
cpu_percent = psutil.cpu_percent()
Expand All @@ -83,3 +134,17 @@ def compute_system_metrics(self):
for k, v in network_usage.items():
system_metrics[f'network_{k}'] = v
return system_metrics

def compute_gpu_min_max_metrics(self, all_metrics, state):
min_max_metrics = {}

if self.gpu_available:
for key in _GPU_METRICS:
values = torch.tensor([metrics_for_cur_rank[key] for metrics_for_cur_rank in all_metrics])
values = state.device.tensor_to_device(values)
min_rank = int(torch.argmin(values).item())
max_rank = int(torch.argmax(values).item())
min_max_metrics[f'min_{key}_rank_{min_rank}'] = values[min_rank].item()
min_max_metrics[f'max_{key}_rank_{max_rank}'] = values[max_rank].item()

return min_max_metrics
145 changes: 145 additions & 0 deletions composer/checkpoint/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for saving state dicts to disk."""

import logging
import os
import textwrap
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed.checkpoint as DCP
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor

from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file

log = logging.getLogger(__name__)


def save_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:
"""Saves a state dict to local disk.
Args:
state_dict (Dict[str,Any]): The state dict to save.
destination_file_path (str): The path to save the state dict to. If sharded,
this should be the pth to a directory. Otherwise, it should be a path to a file.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the state dict in. One of 'pt', 'hf', or 'safetensor'.
Returns:
str: The full path to the saved state dict if (sharded is false and rank 0) or if sharded is true, otherwise None.
"""
if state_dict == {}:
return None
if is_state_dict_sharded(state_dict):
path_saved = _save_sharded_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
else:
if dist.get_global_rank() == 0:
path_saved = _save_full_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
else:
path_saved = None

return path_saved


def _save_sharded_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt',
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if state_dict == {}:
return None

# If user specifies filename instead of directory suffixes, strip them and warn
if len(Path(destination_file_path).suffixes) > 0:
stripped_path = _strip_suffixes(destination_file_path)
warnings.warn(
textwrap.dedent(
f"""Sharded checkpoints require a directory path not a file path:
{destination_file_path} will have its extensions stripped and checkpoints will be saved in {stripped_path}
as {stripped_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}""",
),
)
destination_file_path = stripped_path

if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')

log.debug(
f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
)

# For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions
# of torch (and makes pyright happier) that we support, so we use it for now.
if version.parse(torch.__version__) < version.parse('2.2.0'):
DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
else:
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))

return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME


def _save_full_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')

if dist.get_global_rank() == 0:
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
return destination_file_path
return None


def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
"""Determines if the state dict is sharded.
Args:
state_dict (Dict[str, Any]): The state dict to check.
Returns:
bool: Whether the state dict is sharded.
"""
for value in state_dict.values():
if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
return True
if isinstance(value, Dict):
is_sharded = is_state_dict_sharded(value)
if is_sharded:
return True
return False


def _strip_suffixes(path: Union[str, Path]) -> str:
path = Path(path)
for _ in path.suffixes:
path = path.with_suffix('')

return str(path)
12 changes: 7 additions & 5 deletions composer/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,13 @@ def _parse_args():
if args.nproc < 1:
raise ValueError('The nproc must be 1 or greater')

if args.world_size is None and 'WORLD_SIZE' in os.environ:
args.world_size = int(os.environ['WORLD_SIZE'])
if args.world_size is None:
if 'WORLD_SIZE' in os.environ and os.environ.get('LOCAL_WORLD_SIZE') != os.environ['WORLD_SIZE']:
# Use WORLD_SIZE env var if set and running multinode. Otherwise, default to nproc
# to enable easy overriding of number of processes when on a single node.
args.world_size = int(os.environ['WORLD_SIZE'])
else:
args.world_size = args.nproc

if args.base_rank is None and 'BASE_RANK' in os.environ:
args.base_rank = int(os.environ['BASE_RANK'])
Expand All @@ -212,9 +217,6 @@ def _parse_args():
if args.master_port is None and 'MASTER_PORT' in os.environ:
args.master_port = int(os.environ['MASTER_PORT'])

if args.world_size is None:
args.world_size = args.nproc

if args.world_size < args.nproc:
raise ValueError(f'world_size({args.world_size}) cannot be less than nproc({args.nproc})')

Expand Down
Loading

0 comments on commit f8f1145

Please sign in to comment.