sync processes on each metric tracking
IlyasMoutawwakil committed Feb 19, 2024
1 parent a50f072 commit 8131e4c
Showing 9 changed files with 85 additions and 69 deletions.
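Every tracker and backend change below applies the same pattern: before a tracker starts measuring (and before the PyTorch backend cleans up), each rank of a distributed run waits on a barrier so that all processes enter and leave the measured section together. A minimal sketch of that pattern, using only the public torch.distributed API (the repository guards the check behind its own is_torch_distributed_available() helper; the function name below is illustrative):

# Minimal sketch of the synchronization pattern added in this commit.
# It is not a copy of the repository code; the helper name is made up.
import torch.distributed as dist


def wait_for_all_ranks() -> None:
    # Only meaningful once a process group exists (e.g. when running under torchrun).
    if dist.is_available() and dist.is_initialized():
        dist.barrier()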
Makefile (6 changes: 5 additions & 1 deletion)

@@ -1,5 +1,9 @@
 # List of targets that are not associated with files
-.PHONY: quality style install install_dev_cpu install_dev_gpu, build_docker_cpu, build_docker_cuda, build_docker_rocm, test_cli_cpu_neural_compressor, test_cli_cpu_onnxruntime, test_cli_cpu_openvino, test_cli_cpu_pytorch, test_cli_rocm_pytorch, test_api_cpu, test_api_cuda, test_api_rocm, test_api_misc
+.PHONY: quality style install \
+	build_docker_cpu, build_docker_cuda, build_docker_rocm, \
+	test_cli_cpu_pytorch, test_cli_rocm_pytorch, \
+	test_cli_cpu_neural_compressor, test_cli_cpu_onnxruntime, test_cli_cpu_openvino, \
+	test_api_cpu, test_api_cuda, test_api_rocm, test_api_misc
 
 quality:
 	ruff check .
optimum_benchmark/backends/pytorch/backend.py (14 changes: 11 additions & 3 deletions)

@@ -9,7 +9,7 @@
 from .config import PyTorchConfig
 from ..peft_utils import get_peft_config_class
 from ..transformers_utils import randomize_weights
-from ...import_utils import is_deepspeed_available, is_peft_available
+from ...import_utils import is_deepspeed_available, is_peft_available, is_torch_distributed_available
 
 import torch
 from datasets import Dataset
@@ -22,9 +22,13 @@
 if is_peft_available():
     from peft import get_peft_model  # type: ignore
 
+if is_torch_distributed_available():
+    import torch.distributed
+
 if is_deepspeed_available():
     from deepspeed import init_inference  # type: ignore
 
+
 # disable other loggers
 datasets_logging.set_verbosity_error()
 transformers_logging.set_verbosity_error()
@@ -277,8 +281,8 @@ def is_awq_quantized(self) -> bool:
     def is_exllamav2(self) -> bool:
         return (
             self.is_gptq_quantized
-            and "exllama_config" in self.quantization_config
-            and self.quantization_config["exllama_config"].get("version", None) == 2
+            and hasattr(self.quantization_config, "exllama_config")
+            and self.quantization_config.exllama_config.get("version", None) == 2
         )
 
     @property
@@ -356,6 +360,10 @@ def seed(self):
         torch.cuda.manual_seed_all(self.config.seed)
 
     def clean(self) -> None:
+        if is_torch_distributed_available() and torch.distributed.is_initialized():
+            LOGGER.info("\t+ Waiting for distributed processes to finish before cleaning backend")
+            torch.distributed.barrier()
+
         super().clean()
 
         if hasattr(self, "tmpdir"):
optimum_benchmark/launchers/process/launcher.py (40 changes: 22 additions & 18 deletions)

@@ -1,15 +1,13 @@
 from typing import Callable
-import multiprocessing as mp
 from logging import getLogger
-from multiprocessing import Process, Queue, Lock
-from multiprocessing.synchronize import Lock as LockType
 
 from ..base import Launcher
 from .config import ProcessConfig
 from ...logging_utils import setup_logging
 from ..isolation_utils import device_isolation
 from ...benchmarks.report import BenchmarkReport
 
+import torch.multiprocessing as mp
 
 LOGGER = getLogger("process")
 
@@ -25,37 +23,43 @@ def __init__(self, config: ProcessConfig):
         mp.set_start_method(self.config.start_method, force=True)
 
     def launch(self, worker: Callable, *worker_args) -> BenchmarkReport:
-        lock = Lock()
-        queue = Queue(1000)
-        current_log_level = getLogger().getEffectiveLevel()
-        worker_process = Process(
-            target=target, args=(worker, queue, lock, current_log_level, *worker_args), daemon=False
-        )
+        log_level = getLogger().getEffectiveLevel()
+
+        ctx = mp.get_context(self.config.start_method)
+        queue = ctx.Queue()
+        lock = ctx.Lock()
 
         with device_isolation(enabled=self.config.device_isolation):
-            worker_process.start()
-            LOGGER.info(f"\t+ Launched worker process with PID {worker_process.pid}.")
-            worker_process.join()
+            process_context = mp.start_processes(
+                entrypoint,
+                args=(worker, queue, lock, log_level, *worker_args),
+                start_method=self.config.start_method,
+                daemon=False,
+                join=False,
+                nprocs=1,
+            )
+            LOGGER.info(f"\t+ Launched worker process(es) with PID(s): {process_context.pids()}")
+            while not process_context.join():
+                pass
 
         try:
-            report = queue.get()
+            report: BenchmarkReport = queue.get()
         except EOFError:
-            LOGGER.error(f"\t+ Worker process exited with code {worker_process.exitcode}, forwarding...")
-            exit(worker_process.exitcode)
+            raise RuntimeError("Worker process did not return a report")
 
         return report
 
 
-def target(fn: Callable, queue: Queue, lock: LockType, log_level: str, *args):
+def entrypoint(_, worker, queue, lock, log_level, *worker_args):
     """
     This a pickalable function that correctly sets up the logging configuration for the worker process,
     and puts the output of the worker function into a lock-protected queue.
     """
 
     setup_logging(log_level)
 
-    out = fn(*args)
+    worker_output = worker(*worker_args)
 
     lock.acquire()
-    queue.put(out)
+    queue.put(worker_output)
     lock.release()
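For context on the launcher rewrite above: torch.multiprocessing.start_processes always invokes its target as fn(process_index, *args), which is why the new entrypoint accepts and ignores a leading positional argument. A hedged, standalone usage sketch (names and argument values are illustrative, not taken from the repository):

# Standalone sketch of the start_processes pattern used by the new process launcher.
import torch.multiprocessing as mp


def entrypoint(process_index, message):
    # start_processes always passes the process index as the first argument.
    print(f"[worker {process_index}] {message}")


if __name__ == "__main__":
    process_context = mp.start_processes(
        entrypoint, args=("hello",), nprocs=1, join=False, start_method="spawn"
    )
    # join() returns True once every spawned process has exited.
    while not process_context.join():
        pass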
optimum_benchmark/launchers/torchrun/launcher.py (29 changes: 13 additions & 16 deletions)

@@ -1,6 +1,4 @@
-import multiprocessing as mp
 from logging import getLogger
-from multiprocessing import Queue, Lock
 from typing import Callable, Dict, Any, List
 
 from ..base import Launcher
@@ -10,6 +8,7 @@
 from ...logging_utils import setup_logging
 
 import torch.distributed
+import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing import Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.launcher.api import LaunchConfig, launch_agent
@@ -29,6 +28,7 @@ def __init__(self, config: TorchrunConfig):
         mp.set_start_method(self.config.start_method, force=True)
 
     def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]:
+        log_level = getLogger().getEffectiveLevel()
         launch_config = LaunchConfig(
             min_nodes=self.config.min_nodes,
             max_nodes=self.config.max_nodes,
@@ -47,9 +47,10 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]:
             local_addr=self.config.local_addr,
             log_dir=self.config.log_dir,
         )
-        lock = Lock()
-        queue = Queue(1)
-        log_level = getLogger().getEffectiveLevel()
+
+        ctx = mp.get_context(self.config.start_method)
+        queue = ctx.Queue()
+        lock = ctx.Lock()
 
         with device_isolation(enabled=self.config.device_isolation):
             LOGGER.info(f"\t+ Launching torchrun agent with {self.config.nproc_per_node} workers processes")
@@ -69,32 +70,28 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]:
         else:
             raise ValueError("No benchmark report was returned by the workers")
 
-        setup_logging(level=log_level)
-        report.log()
-
         return report
 
 
 @record
-def entrypoint(fn, queue, lock, log_level, *args):
+def entrypoint(worker, queue, lock, log_level, *worker_args):
     """
     This a pickalable function that correctly sets up the logging configuration
     """
     if not torch.distributed.is_initialized():
-        # initialize the process group if not already initialized
-        backend = "nccl" if torch.cuda.is_available() else "gloo"
-        torch.distributed.init_process_group(backend=backend)
+        torch.distributed.init_process_group()
 
     rank = torch.distributed.get_rank()
 
-    if rank == 0:
-        setup_logging(level=log_level, prefix=f"RANK-{rank}")
-
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
 
+    if rank == 0:
+        setup_logging(level=log_level, prefix="RANK-0")
+    else:
+        setup_logging(level="ERROR")
 
-    output = fn(*args)
+    output = worker(*worker_args)
 
     lock.acquire()
     queue.put(output)
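The new entrypoint also makes logging rank-aware: rank 0 keeps the configured level while every other rank is silenced to ERROR, since all ranks now run the same trackers. A rough standalone sketch of that idea with the standard logging module (setup_rank_logging and the format string are illustrative, not the repository's setup_logging):

# Rough sketch of rank-aware logging; assumes torchrun has exported the RANK variable.
import logging
import os


def setup_rank_logging(log_level: int) -> None:
    rank = int(os.environ.get("RANK", "0"))
    level = log_level if rank == 0 else logging.ERROR
    logging.basicConfig(level=level, format=f"[RANK-{rank}] %(levelname)s - %(message)s", force=True)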
optimum_benchmark/trackers/energy.py (15 changes: 9 additions & 6 deletions)

@@ -5,7 +5,10 @@
 from typing import Optional, Literal, List
 
 from ..system_utils import get_gpu_device_ids
-from ..import_utils import is_codecarbon_available
+from ..import_utils import is_codecarbon_available, is_torch_distributed_available
+
+if is_torch_distributed_available():
+    import torch.distributed
 
 if is_codecarbon_available():
     from codecarbon import EmissionsTracker, OfflineEmissionsTracker  # type: ignore
@@ -78,11 +81,6 @@ def __init__(self, device: str, device_ids: Optional[str] = None):
         self.device = device
         self.device_ids = device_ids
 
-        self.cpu_energy: float = 0
-        self.gpu_energy: float = 0
-        self.ram_energy: float = 0
-        self.total_energy: float = 0
-
         if self.device == "cuda":
             if self.device_ids is None:
                 LOGGER.warning("\t+ `device=cuda` but `device_ids` not provided. Using all available CUDA devices.")
@@ -91,6 +89,11 @@ def __init__(self, device: str, device_ids: Optional[str] = None):
                 self.device_ids = list(map(int, self.device_ids.split(",")))
             LOGGER.info(f"\t+ Tracking GPU energy on devices {self.device_ids}")
 
+        if is_torch_distributed_available() and torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+        self.reset()
+
     def reset(self):
         self.cpu_energy = 0
         self.gpu_energy = 0
optimum_benchmark/trackers/latency.py (17 changes: 8 additions & 9 deletions)

@@ -93,21 +93,20 @@ def __init__(self, device: str, backend: str):
         self.device = device
         self.backend = backend
 
-        self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized()
-
-        self.start_events: List[Union[float, torch.cuda.Event]] = []
-        self.end_events: List[Union[float, torch.cuda.Event]] = []
-        self.start_time: float = time.perf_counter()
-
         if self.backend == "pytorch" and self.device == "cuda":
             LOGGER.info("\t+ Tracking Pytorch CUDA latency")
         else:
             LOGGER.info("\t+ Tracking CPU latency")
 
+        if is_torch_distributed_available() and torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+        self.reset()
+
     def reset(self):
-        self.start_time = time.perf_counter()
-        self.start_events = []
-        self.end_events = []
+        self.start_events: List[Union[float, torch.cuda.Event]] = []
+        self.end_events: List[Union[float, torch.cuda.Event]] = []
+        self.start_time: float = time.perf_counter()
 
     @contextmanager
     def track(self):
optimum_benchmark/trackers/memory.py (8 changes: 7 additions & 1 deletion)

@@ -7,7 +7,10 @@
 from multiprocessing.connection import Connection
 
 from ..system_utils import get_gpu_device_ids, is_nvidia_system, is_rocm_system, get_rocm_version
-from ..import_utils import is_pynvml_available, is_amdsmi_available, is_torch_available
+from ..import_utils import is_pynvml_available, is_amdsmi_available, is_torch_available, is_torch_distributed_available
+
+if is_torch_distributed_available():
+    import torch.distributed
 
 if is_nvidia_system() and is_pynvml_available():
     import pynvml
@@ -88,6 +91,9 @@ def __init__(self, device: str, backend: str, device_ids: Optional[str] = None):
             )
             LOGGER.info(f"\t+ Tracking Allocated/Reserved memory of {num_pytorch_devices} Pytorch CUDA devices")
 
+        if is_torch_distributed_available() and torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
         self.reset()
 
     def reset(self):
tests/configs/_base_.yaml (15 changes: 3 additions & 12 deletions)

@@ -2,9 +2,8 @@ defaults:
   - launcher: process # isolated process launcher
   - experiment # inheriting experiment schema
   - _self_ # for hydra 1.1 compatibility
-  - override hydra/hydra_logging: colorlog # colorful logging
-  - override hydra/job_logging: colorlog # colorful logging
-  - override hydra/launcher: joblib # for parallelization
+  - override hydra/hydra_logging: colorlog
+  - override hydra/job_logging: colorlog
 
 experiment_name: ${device}_${benchmark.name}_${backend.name}_${task}
 
@@ -20,13 +19,5 @@ hydra:
   # change working directory to the run directory
   chdir: true
   env_set:
-    # set environment variable OVERRIDE_BENCHMARKS to 1
-    # to not skip benchmarks that have been run before
+    # to not skip benchmarks if results already exist
     OVERRIDE_BENCHMARKS: 1
-
-  # we are using joblib launcher to parallelize testing since
-  # we're having ccorrect benchmarks is not important while testing
-  # to force sequential execution, uncomment the following three lines
-  # launcher:
-  #   n_jobs: -1 # 1 for debugging
-  #   batch_size: auto # 1 for debugging
tests/test_api.py (10 changes: 7 additions & 3 deletions)

@@ -8,6 +8,7 @@
 from optimum_benchmark.launchers.inline.config import InlineConfig
 from optimum_benchmark.backends.pytorch.config import PyTorchConfig
 from optimum_benchmark.launchers.process.config import ProcessConfig
+from optimum_benchmark.launchers.torchrun.config import TorchrunConfig
 from optimum_benchmark.benchmarks.inference.config import INPUT_SHAPES
 from optimum_benchmark.benchmarks.training.config import DATASET_SHAPES
 from optimum_benchmark.generators.input_generator import InputGenerator
@@ -36,7 +37,11 @@
     ("transformers", "image-classification", "google/vit-base-patch16-224"),
     ("transformers", "semantic-segmentation", "google/vit-base-patch16-224"),
 ]
-LAUNCHER_CONFIGS = [InlineConfig(device_isolation=False), ProcessConfig(device_isolation=False)]
+LAUNCHER_CONFIGS = [
+    InlineConfig(device_isolation=False),
+    ProcessConfig(device_isolation=False),
+    TorchrunConfig(device_isolation=False, nproc_per_node=2),
+]
 BACKENDS = ["pytorch", "none"]
 DEVICES = ["cpu", "cuda"]
 
@@ -88,8 +93,7 @@ def test_api_memory_tracker(device, backend):
         else:
             measured_memory = final_memory.max_vram - initial_memory.max_vram
             if torch.version.hip is not None:
-                # skip vram measurement for ROCm
-                return
+                return  # skip vram measurement for ROCm
     else:
         measured_memory = final_memory.max_ram - initial_memory.max_ram
 
