removing debugging code
hkchengrex committed Dec 30, 2023
1 parent 8f06bd7 commit 83f1c1d
Showing 1 changed file with 1 addition and 51 deletions.
52 changes: 1 addition & 51 deletions cutie/train.py
@@ -11,7 +11,6 @@
 import numpy as np
 import torch
 import torch.distributed as distributed
-import pynvml
 
 from cutie.model.trainer import Trainer
 from cutie.dataset.setup_training_data import setup_pre_training_datasets, setup_main_training_datasets
@@ -44,55 +43,11 @@ def train(cfg: DictConfig):
     info_if_rank_zero(f'All configuration: {cfg}')
     info_if_rank_zero(f'Number of detected GPUs: {num_gpus}')
 
-    acquire = cfg.acquire
-
     # cuda setup
     torch.cuda.set_device(local_rank)
     if cfg.cudnn_benchmark:
         torch.backends.cudnn.benchmark = True
 
-    temp = []
-
-    def wait():
-        cuda_id = torch.cuda.current_device()
-        if "CUDA_VISIBLE_DEVICES" in os.environ:
-            ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
-            cuda_id = ids[cuda_id]  # remap
-        pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(cuda_id)
-        free = pynvml.nvmlDeviceGetMemoryInfo(handle).free
-        free = free // (2**28) - 1
-        print(f'{local_rank}: Waiting... {free}')
-        while free < 7:
-            free = pynvml.nvmlDeviceGetMemoryInfo(handle).free
-            free = free // (2**28) - 1
-            torch.cuda.empty_cache()
-            time.sleep(60)
-        print(f'{local_rank}: Barrier... {free}')
-        distributed.barrier()
-        print(f'{local_rank}: Go! {free}')
-
-    wait()
-
-    def occupy(temp):
-        if not acquire:
-            return []
-        # block
-        cuda_id = torch.cuda.current_device()
-        if "CUDA_VISIBLE_DEVICES" in os.environ:
-            ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
-            cuda_id = ids[cuda_id]  # remap
-        pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(cuda_id)
-        free = pynvml.nvmlDeviceGetMemoryInfo(handle).free
-        free = free // (2**28) - 3
-        if free > 0:
-            print(f'{local_rank}: Occupying... {free}')
-            temp.extend(
-                [torch.zeros((2**28), dtype=torch.uint8, device='cuda') for _ in range(free)])
-        print('After', len(temp), pynvml.nvmlDeviceGetMemoryInfo(handle).free)
-        return temp
-
     # number of dataloader workers
     cfg.num_workers //= num_gpus
     info_if_rank_zero(f'Number of dataloader workers (per GPU): {cfg.num_workers}')
@@ -197,18 +152,13 @@ def occupy(temp):
                 trainer.do_pass(data, curr_iter)
                 curr_iter += 1
 
-                if curr_iter % 100 == 0:
-                    temp = occupy(temp)
-
                 if curr_iter >= total_iterations:
                     break
     finally:
         if not cfg.debug:
             trainer.save_weights(curr_iter)
             trainer.save_checkpoint(curr_iter)
 
-    del temp
-    temp = []
     torch.cuda.empty_cache()
     weights_in_memory = trainer.weights()
 
@@ -217,4 +167,4 @@ def occupy(temp):
 
 
 if __name__ == '__main__':
-    train()
\ No newline at end of file
+    train()
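
For reference, the deleted block implemented a two-phase "wait, then occupy" debugging hack: each rank polled NVML until its GPU had roughly 2 GiB free (counted in 256 MiB chunks), synchronized on a distributed barrier, and thereafter (every 100 iterations, when cfg.acquire was set) filled nearly all remaining free memory with dummy uint8 tensors that were released in the finally block. Below is a minimal standalone sketch of the same pattern, not code from the repository: it assumes pynvml is installed, and the function names and defaults are illustrative only.

# Sketch of the wait-then-occupy pattern removed by this commit.
# Assumes pynvml is installed; names here are illustrative, not from Cutie.
import os
import time

import pynvml
import torch

CHUNK_BYTES = 2**28  # the removed code counted free memory in 256 MiB chunks


def physical_cuda_index() -> int:
    # Map the current torch device index to the physical index NVML expects,
    # honoring any CUDA_VISIBLE_DEVICES remapping.
    cuda_id = torch.cuda.current_device()
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        ids = list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
        cuda_id = ids[cuda_id]
    return cuda_id


def wait_for_free_memory(min_chunks: int = 7, poll_seconds: int = 60) -> None:
    # Block until the GPU has at least min_chunks * 256 MiB free.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(physical_cuda_index())
    while pynvml.nvmlDeviceGetMemoryInfo(handle).free // CHUNK_BYTES < min_chunks:
        torch.cuda.empty_cache()
        time.sleep(poll_seconds)


def occupy_free_memory(reserve_chunks: int = 3) -> list:
    # Grab most of the currently free memory as dummy tensors, keeping a
    # small reserve; drop the returned list to release the memory again.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(physical_cuda_index())
    free_chunks = pynvml.nvmlDeviceGetMemoryInfo(handle).free // CHUNK_BYTES - reserve_chunks
    return [torch.zeros(CHUNK_BYTES, dtype=torch.uint8, device='cuda')
            for _ in range(max(free_chunks, 0))]

Holding memory like this is only useful for debugging allocation contention on shared machines; it competes with PyTorch's caching allocator, which is presumably why the hack did not survive into the released training script.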
