
Commit 9db3d5f: use device_ids
IlyasMoutawwakil committed Feb 19, 2024
1 parent ff729ba commit 9db3d5f
Showing 3 changed files with 6 additions and 6 deletions.
optimum_benchmark/trackers/energy.py: 4 changes (2 additions & 2 deletions)
@@ -134,14 +134,14 @@ def track(self, interval=1, file_prefix="method"):
         )

         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

         self.emission_tracker.start()
         yield
         self.emission_tracker.stop()

         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

         self.cpu_energy = self.emission_tracker._total_cpu_energy.kWh
         self.gpu_energy = self.emission_tracker._total_gpu_energy.kWh
optimum_benchmark/trackers/latency.py: 4 changes (2 additions & 2 deletions)
@@ -109,15 +109,15 @@ def reset(self):
     @contextmanager
     def track(self):
         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

         if self.backend == "pytorch" and self.device == "cuda":
             yield from self._pytorch_cuda_latency()
         else:
             yield from self._cpu_latency()

         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

     def _pytorch_cuda_latency(self):
         start = torch.cuda.Event(enable_timing=True)
optimum_benchmark/trackers/memory.py: 4 changes (2 additions & 2 deletions)
@@ -103,7 +103,7 @@ def reset(self):
     @contextmanager
     def track(self):
         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

         if self.device == "cuda" and self.backend == "pytorch":
             yield from self._cuda_pytorch_memory()
@@ -113,7 +113,7 @@ def track(self):
             yield from self._cpu_memory()

         if self.distributed:
-            torch.distributed.monitored_barrier()
+            torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

     def _cuda_pytorch_memory(self):
         torch.cuda.empty_cache()
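For context, all three trackers now share the same synchronization pattern: torch.distributed.barrier with device_ids pins the NCCL barrier to the calling process's local GPU, whereas monitored_barrier is only implemented for the gloo backend. Below is a minimal standalone sketch of that pattern; it is not code from this repository, the run_tracked name is made up for illustration, and it assumes the process group has already been initialized with the nccl backend and one process per GPU.

from contextlib import contextmanager

import torch
import torch.distributed


@contextmanager
def run_tracked(distributed: bool):
    # Synchronize all ranks before the measured region so every process starts
    # tracking from the same point in time.
    if distributed:
        # Map the global rank to a local GPU index and pin the NCCL barrier to it.
        torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

    yield  # the tracked work (energy, latency, or memory measurement) happens here

    # Synchronize again so no rank stops tracking before the slowest one finishes.
    if distributed:
        torch.distributed.barrier(device_ids=[torch.distributed.get_rank() % torch.cuda.device_count()])

Launched with one process per GPU (for example via torchrun --nproc_per_node=<num_gpus>), every rank enters and leaves the measured region together, so the per-rank measurements cover the same time window.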
