From 8131e4c30ca23206ad047371786db6b6ea0318f0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 19 Feb 2024 03:17:32 +0000 Subject: [PATCH] sync processes on each metric tracking --- Makefile | 6 ++- optimum_benchmark/backends/pytorch/backend.py | 14 +++++-- .../launchers/process/launcher.py | 40 ++++++++++--------- .../launchers/torchrun/launcher.py | 29 ++++++-------- optimum_benchmark/trackers/energy.py | 15 ++++--- optimum_benchmark/trackers/latency.py | 17 ++++---- optimum_benchmark/trackers/memory.py | 8 +++- tests/configs/_base_.yaml | 15 ++----- tests/test_api.py | 10 +++-- 9 files changed, 85 insertions(+), 69 deletions(-) diff --git a/Makefile b/Makefile index bbbed230..0253c183 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,9 @@ # List of targets that are not associated with files -.PHONY: quality style install install_dev_cpu install_dev_gpu, build_docker_cpu, build_docker_cuda, build_docker_rocm, test_cli_cpu_neural_compressor, test_cli_cpu_onnxruntime, test_cli_cpu_openvino, test_cli_cpu_pytorch, test_cli_rocm_pytorch, test_api_cpu, test_api_cuda, test_api_rocm, test_api_misc +.PHONY: quality style install \ + build_docker_cpu, build_docker_cuda, build_docker_rocm, \ + test_cli_cpu_pytorch, test_cli_rocm_pytorch, \ + test_cli_cpu_neural_compressor, test_cli_cpu_onnxruntime, test_cli_cpu_openvino, \ + test_api_cpu, test_api_cuda, test_api_rocm, test_api_misc quality: ruff check . diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 4d2b132c..67034c20 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -9,7 +9,7 @@ from .config import PyTorchConfig from ..peft_utils import get_peft_config_class from ..transformers_utils import randomize_weights -from ...import_utils import is_deepspeed_available, is_peft_available +from ...import_utils import is_deepspeed_available, is_peft_available, is_torch_distributed_available import torch from datasets import Dataset @@ -22,9 +22,13 @@ if is_peft_available(): from peft import get_peft_model # type: ignore +if is_torch_distributed_available(): + import torch.distributed + if is_deepspeed_available(): from deepspeed import init_inference # type: ignore + # disable other loggers datasets_logging.set_verbosity_error() transformers_logging.set_verbosity_error() @@ -277,8 +281,8 @@ def is_awq_quantized(self) -> bool: def is_exllamav2(self) -> bool: return ( self.is_gptq_quantized - and "exllama_config" in self.quantization_config - and self.quantization_config["exllama_config"].get("version", None) == 2 + and hasattr(self.quantization_config, "exllama_config") + and self.quantization_config.exllama_config.get("version", None) == 2 ) @property @@ -356,6 +360,10 @@ def seed(self): torch.cuda.manual_seed_all(self.config.seed) def clean(self) -> None: + if is_torch_distributed_available() and torch.distributed.is_initialized(): + LOGGER.info("\t+ Waiting for distributed processes to finish before cleaning backend") + torch.distributed.barrier() + super().clean() if hasattr(self, "tmpdir"): diff --git a/optimum_benchmark/launchers/process/launcher.py b/optimum_benchmark/launchers/process/launcher.py index 3b9678a3..e852a9b4 100644 --- a/optimum_benchmark/launchers/process/launcher.py +++ b/optimum_benchmark/launchers/process/launcher.py @@ -1,8 +1,5 @@ from typing import Callable -import multiprocessing as mp from logging import getLogger -from multiprocessing import Process, Queue, Lock -from 
multiprocessing.synchronize import Lock as LockType from ..base import Launcher from .config import ProcessConfig @@ -10,6 +7,7 @@ from ..isolation_utils import device_isolation from ...benchmarks.report import BenchmarkReport +import torch.multiprocessing as mp LOGGER = getLogger("process") @@ -25,28 +23,34 @@ def __init__(self, config: ProcessConfig): mp.set_start_method(self.config.start_method, force=True) def launch(self, worker: Callable, *worker_args) -> BenchmarkReport: - lock = Lock() - queue = Queue(1000) - current_log_level = getLogger().getEffectiveLevel() - worker_process = Process( - target=target, args=(worker, queue, lock, current_log_level, *worker_args), daemon=False - ) + log_level = getLogger().getEffectiveLevel() + + ctx = mp.get_context(self.config.start_method) + queue = ctx.Queue() + lock = ctx.Lock() with device_isolation(enabled=self.config.device_isolation): - worker_process.start() - LOGGER.info(f"\t+ Launched worker process with PID {worker_process.pid}.") - worker_process.join() + process_context = mp.start_processes( + entrypoint, + args=(worker, queue, lock, log_level, *worker_args), + start_method=self.config.start_method, + daemon=False, + join=False, + nprocs=1, + ) + LOGGER.info(f"\t+ Launched worker process(es) with PID(s): {process_context.pids()}") + while not process_context.join(): + pass try: - report = queue.get() + report: BenchmarkReport = queue.get() except EOFError: - LOGGER.error(f"\t+ Worker process exited with code {worker_process.exitcode}, forwarding...") - exit(worker_process.exitcode) + raise RuntimeError("Worker process did not return a report") return report -def target(fn: Callable, queue: Queue, lock: LockType, log_level: str, *args): +def entrypoint(_, worker, queue, lock, log_level, *worker_args): """ This a pickalable function that correctly sets up the logging configuration for the worker process, and puts the output of the worker function into a lock-protected queue. 
@@ -54,8 +58,8 @@ def target(fn: Callable, queue: Queue, lock: LockType, log_level: str, *args): setup_logging(log_level) - out = fn(*args) + worker_output = worker(*worker_args) lock.acquire() - queue.put(out) + queue.put(worker_output) lock.release() diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py index 30fec120..67cf8281 100644 --- a/optimum_benchmark/launchers/torchrun/launcher.py +++ b/optimum_benchmark/launchers/torchrun/launcher.py @@ -1,6 +1,4 @@ -import multiprocessing as mp from logging import getLogger -from multiprocessing import Queue, Lock from typing import Callable, Dict, Any, List from ..base import Launcher @@ -10,6 +8,7 @@ from ...logging_utils import setup_logging import torch.distributed +import torch.multiprocessing as mp from torch.distributed.elastic.multiprocessing import Std from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.launcher.api import LaunchConfig, launch_agent @@ -29,6 +28,7 @@ def __init__(self, config: TorchrunConfig): mp.set_start_method(self.config.start_method, force=True) def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: + log_level = getLogger().getEffectiveLevel() launch_config = LaunchConfig( min_nodes=self.config.min_nodes, max_nodes=self.config.max_nodes, @@ -47,9 +47,10 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: local_addr=self.config.local_addr, log_dir=self.config.log_dir, ) - lock = Lock() - queue = Queue(1) - log_level = getLogger().getEffectiveLevel() + + ctx = mp.get_context(self.config.start_method) + queue = ctx.Queue() + lock = ctx.Lock() with device_isolation(enabled=self.config.device_isolation): LOGGER.info(f"\t+ Launching torchrun agent with {self.config.nproc_per_node} workers processes") @@ -69,32 +70,28 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: else: raise ValueError("No benchmark report was returned by the workers") + setup_logging(level=log_level) report.log() return report @record -def entrypoint(fn, queue, lock, log_level, *args): +def entrypoint(worker, queue, lock, log_level, *worker_args): """ This a pickalable function that correctly sets up the logging configuration """ - if not torch.distributed.is_initialized(): - # initialize the process group if not already initialized - backend = "nccl" if torch.cuda.is_available() else "gloo" - torch.distributed.init_process_group(backend=backend) + torch.distributed.init_process_group() rank = torch.distributed.get_rank() + if rank == 0: + setup_logging(level=log_level, prefix=f"RANK-{rank}") + if torch.cuda.is_available(): torch.cuda.set_device(rank) - if rank == 0: - setup_logging(level=log_level, prefix="RANK-0") - else: - setup_logging(level="ERROR") - - output = fn(*args) + output = worker(*worker_args) lock.acquire() queue.put(output) diff --git a/optimum_benchmark/trackers/energy.py b/optimum_benchmark/trackers/energy.py index 4f2cc5ab..ecc64b8c 100644 --- a/optimum_benchmark/trackers/energy.py +++ b/optimum_benchmark/trackers/energy.py @@ -5,7 +5,10 @@ from typing import Optional, Literal, List from ..system_utils import get_gpu_device_ids -from ..import_utils import is_codecarbon_available +from ..import_utils import is_codecarbon_available, is_torch_distributed_available + +if is_torch_distributed_available(): + import torch.distributed if is_codecarbon_available(): from codecarbon import EmissionsTracker, OfflineEmissionsTracker # type: ignore @@ -78,11 +81,6 @@ def __init__(self, device: 
str, device_ids: Optional[str] = None): self.device = device self.device_ids = device_ids - self.cpu_energy: float = 0 - self.gpu_energy: float = 0 - self.ram_energy: float = 0 - self.total_energy: float = 0 - if self.device == "cuda": if self.device_ids is None: LOGGER.warning("\t+ `device=cuda` but `device_ids` not provided. Using all available CUDA devices.") @@ -91,6 +89,11 @@ def __init__(self, device: str, device_ids: Optional[str] = None): self.device_ids = list(map(int, self.device_ids.split(","))) LOGGER.info(f"\t+ Tracking GPU energy on devices {self.device_ids}") + if is_torch_distributed_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + self.reset() + def reset(self): self.cpu_energy = 0 self.gpu_energy = 0 diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index ef8c74ce..dda1e104 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -93,21 +93,20 @@ def __init__(self, device: str, backend: str): self.device = device self.backend = backend - self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() - - self.start_events: List[Union[float, torch.cuda.Event]] = [] - self.end_events: List[Union[float, torch.cuda.Event]] = [] - self.start_time: float = time.perf_counter() - if self.backend == "pytorch" and self.device == "cuda": LOGGER.info("\t+ Tracking Pytorch CUDA latency") else: LOGGER.info("\t+ Tracking CPU latency") + if is_torch_distributed_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + self.reset() + def reset(self): - self.start_time = time.perf_counter() - self.start_events = [] - self.end_events = [] + self.start_events: List[Union[float, torch.cuda.Event]] = [] + self.end_events: List[Union[float, torch.cuda.Event]] = [] + self.start_time: float = time.perf_counter() @contextmanager def track(self): diff --git a/optimum_benchmark/trackers/memory.py b/optimum_benchmark/trackers/memory.py index 9ae456ca..d1544ccf 100644 --- a/optimum_benchmark/trackers/memory.py +++ b/optimum_benchmark/trackers/memory.py @@ -7,7 +7,10 @@ from multiprocessing.connection import Connection from ..system_utils import get_gpu_device_ids, is_nvidia_system, is_rocm_system, get_rocm_version -from ..import_utils import is_pynvml_available, is_amdsmi_available, is_torch_available +from ..import_utils import is_pynvml_available, is_amdsmi_available, is_torch_available, is_torch_distributed_available + +if is_torch_distributed_available(): + import torch.distributed if is_nvidia_system() and is_pynvml_available(): import pynvml @@ -88,6 +91,9 @@ def __init__(self, device: str, backend: str, device_ids: Optional[str] = None): ) LOGGER.info(f"\t+ Tracking Allocated/Reserved memory of {num_pytorch_devices} Pytorch CUDA devices") + if is_torch_distributed_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + self.reset() def reset(self): diff --git a/tests/configs/_base_.yaml b/tests/configs/_base_.yaml index c268095c..3d9ea3fb 100644 --- a/tests/configs/_base_.yaml +++ b/tests/configs/_base_.yaml @@ -2,9 +2,8 @@ defaults: - launcher: process # isolated process launcher - experiment # inheriting experiment schema - _self_ # for hydra 1.1 compatibility - - override hydra/hydra_logging: colorlog # colorful logging - - override hydra/job_logging: colorlog # colorful logging - - override hydra/launcher: joblib # for parallelization + - override hydra/hydra_logging: colorlog + - override 
hydra/job_logging: colorlog experiment_name: ${device}_${benchmark.name}_${backend.name}_${task} @@ -20,13 +19,5 @@ hydra: # change working directory to the run directory chdir: true env_set: - # set environment variable OVERRIDE_BENCHMARKS to 1 - # to not skip benchmarks that have been run before + # to not skip benchmarks if results already exist OVERRIDE_BENCHMARKS: 1 - - # we are using joblib launcher to parallelize testing since - # we're having ccorrect benchmarks is not important while testing - # to force sequential execution, uncomment the following three lines - # launcher: - # n_jobs: -1 # 1 for debugging - # batch_size: auto # 1 for debugging diff --git a/tests/test_api.py b/tests/test_api.py index e31b8a51..7a34e2dd 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,6 +8,7 @@ from optimum_benchmark.launchers.inline.config import InlineConfig from optimum_benchmark.backends.pytorch.config import PyTorchConfig from optimum_benchmark.launchers.process.config import ProcessConfig +from optimum_benchmark.launchers.torchrun.config import TorchrunConfig from optimum_benchmark.benchmarks.inference.config import INPUT_SHAPES from optimum_benchmark.benchmarks.training.config import DATASET_SHAPES from optimum_benchmark.generators.input_generator import InputGenerator @@ -36,7 +37,11 @@ ("transformers", "image-classification", "google/vit-base-patch16-224"), ("transformers", "semantic-segmentation", "google/vit-base-patch16-224"), ] -LAUNCHER_CONFIGS = [InlineConfig(device_isolation=False), ProcessConfig(device_isolation=False)] +LAUNCHER_CONFIGS = [ + InlineConfig(device_isolation=False), + ProcessConfig(device_isolation=False), + TorchrunConfig(device_isolation=False, nproc_per_node=2), +] BACKENDS = ["pytorch", "none"] DEVICES = ["cpu", "cuda"] @@ -88,8 +93,7 @@ def test_api_memory_tracker(device, backend): else: measured_memory = final_memory.max_vram - initial_memory.max_vram if torch.version.hip is not None: - # skip vram measurement for ROCm - return + return # skip vram measurement for ROCm else: measured_memory = final_memory.max_ram - initial_memory.max_ram
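
Below is a minimal sketch (illustrative only, not part of the patch) of the synchronization pattern that each tracker in this change applies before resetting its counters. The helper name maybe_synchronize is hypothetical; the actual trackers inline the same check, guarded by is_torch_distributed_available(), at construction time.

import torch.distributed

def maybe_synchronize() -> None:
    # Wait at a barrier only if a process group has actually been initialized,
    # e.g. by the torchrun launcher's entrypoint; otherwise this is a no-op.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

# Calling this when a tracker is constructed makes every rank start measuring
# energy, latency, or memory from the same point in time, which keeps the
# per-rank numbers comparable in distributed (torchrun) benchmarks.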