High VRAM usage with Blocks to swap on ROCM #1776

Open
csf3o opened this issue Nov 11, 2024 · 2 comments

csf3o commented Nov 11, 2024

Hey, I was testing out FLUX DreamBooth on my 16GB VRAM AMD GPU with blocks to swap = 36, CPU checkpoint offloading, and memory-efficient save.

According to #1764, a value of 36 on NVIDIA should bring VRAM usage down to ~6GB. What I see instead is ~5.4GB while caching latents, after which usage drops to ~300MB during a long pause while the state dicts are loaded into RAM.

Usage then starts rising slowly to ~9.6GB, until it reaches:

2024-11-11 23:36:18 INFO     Loaded Flux: <All keys matched successfully>      flux_utils.py:137
FLUX: Gradient checkpointing enabled. CPU offload: True
2024-11-11 23:36:19 INFO     enable block swap: blocks_to_swap=36              flux_train.py:295
number of trainable parameters: 11901408320
prepare optimizer, data loader etc.
                    INFO     use Adafactor optimizer | {'relative_step':      train_util.py:4764
                             False, 'scale_parameter': False, 'warmup_init':                    
                             False}                                                             
                    WARNING  because max_grad_norm is set, clip_grad_norm is  train_util.py:4792
                             enabled. consider set to 0
                    WARNING  constant_with_warmup will be good                train_util.py:4796
override steps. steps for 2 epochs is: 222

Usage then quickly rises to ~11GB, printing:

[2024-11-11 23:38:03,458] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
running training
  num examples: 111
  num batches per epoch: 111
  num epochs: 2
  batch size per device: 1
  gradient accumulation steps = 1
  total optimization steps: 222
steps:   0%|                                                            | 0/222 [00:00<?, ?it/s] epoch 1/2
2024-11-11 23:38:18 INFO     epoch is incremented. current_epoch: 0, epoch: 1  train_util.py:715

It then spikes up to 15GB and ultimately fails to allocate 1.85GB, printing the following traceback:

Traceback
/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
  return F.linear(input, self.weight, self.bias)                                                                                                                                                
Traceback (most recent call last):
  File "/home/csf3o/kohya_ss/sd-scripts/flux_train.py", line 998, in <module>                                                                                                 
    train(args) 
  File "/home/csf3o/kohya_ss/sd-scripts/flux_train.py", line 787, in train
    model_pred = flux(
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 1084, in forward
    img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 751, in forward
    return torch.utils.checkpoint.checkpoint(
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 496, in checkpoint
    ret = function(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 746, in custom_forward
    outputs = func(*cuda_inputs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 723, in _forward
    attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 449, in attention
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 1.85 GiB. GPU 0 has a total capacity of 15.98 GiB of which 398.00 MiB is free. Of the allocated memory 15.21 GiB is allocated by PyTorch, and 52.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
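For what it's worth, the allocator hint at the end of that message is straightforward to try: PYTORCH_HIP_ALLOC_CONF=expandable_segments:True has to be in the environment before PyTorch sets up its allocator, so it would normally be exported in the shell that runs accelerate launch. A minimal sketch of the same idea for a standalone script (hypothetical, not part of sd-scripts):

import os

# Must be set before torch is imported/initialized, otherwise the
# caching allocator is configured without it.
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Sanity check that the ROCm device is visible.
print(torch.cuda.is_available(), torch.cuda.get_device_name(0))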

I've tried a few different configurations, such as turning sdpa on/off and enabling/disabling full fp16 training.
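Since the allocation fails inside scaled_dot_product_attention, another experiment would be to pin SDPA to a single backend and see whether the peak changes. A rough sketch using PyTorch's backend selector (which backends are actually usable depends on the torch/ROCm build, so treat this as a probe rather than a fix; the shapes are only illustrative):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Tensor shapes roughly in the ballpark of a FLUX block at 1024x1024 (illustrative).
q = torch.randn(1, 24, 4096, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the memory-efficient kernel; if that kernel cannot run
# on this build, PyTorch raises instead of silently falling back to the
# math path, which materializes the full attention matrix.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)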

The command used:

/home/csf3o/kohya_ss/venv/bin/accelerate launch --dynamo_backend no --dynamo_mode default \
--gpu_ids 0 --mixed_precision fp16 --num_processes 1 --num_machines 1 --num_cpu_threads_per_process 2 \
/home/csf3o/kohya_ss/sd-scripts/flux_train.py --config_file /home/csf3o/kohya_ss/outputs/abc1/config.toml \
--text_encoder_batch_size 1
config.toml
adaptive_noise_scale = 0                                                                        
ae = "/home/csf3o/kohya_ss/models/Flux1-dev/ae.safetensors"
blocks_to_swap = 36                                                                             
bucket_no_upscale = true  
bucket_reso_steps = 64            
cache_latents = true                                                                                                                                                                            
cache_latents_to_disk = true
caption_dropout_every_n_epochs = 0
caption_dropout_rate = 0                                                                        
caption_extension = ".caption"
clip_l = "/home/csf3o/kohya_ss/models/Flux1-dev/clip_l.safetensors"
clip_skip = 1                
cpu_offload_checkpointing = true
discrete_flow_shift = 3
double_blocks_to_swap = 0
dynamo_backend = "no"   
enable_bucket = true                                                                            
epoch = 1                   
fused_backward_pass = true 
gradient_accumulation_steps = 1
guidance_scale = 3.5
huber_c = 0.1                                                                                   
huber_schedule = "snr" 
keep_tokens = 0
learning_rate = 1e-5
learning_rate_te = 1e-5
logging_dir = "/home/csf3o/kohya_ss/logs/abc1"
loss_type = "l2"
lr_scheduler = "adafactor"
lr_scheduler_args = []
lr_scheduler_num_cycles = 1
lr_scheduler_power = 1
lr_warmup_steps = 0
max_bucket_reso = 2048
max_data_loader_n_workers = 0
max_timestep = 1000
max_token_length = 75
max_train_epochs = 2
max_train_steps = 1073
mem_eff_save = true
min_bucket_reso = 256
mixed_precision = "fp16"
model_prediction_type = "sigma_scaled"
multires_noise_discount = 0.3
multires_noise_iterations = 0
noise_offset = 0
noise_offset_type = "Original"
optimizer_args = [ "relative_step=False", "scale_parameter=False", "warmup_init=False",]
optimizer_type = "Adafactor"
output_dir = "/home/csf3o/kohya_ss/outputs/abc1"
output_name = "abc1_last"
persistent_data_loader_workers = 0
pretrained_model_name_or_path = "/home/csf3o/kohya_ss/models/Flux1-dev/flux_dev.safetensors"
prior_loss_weight = 1
resolution = "1024,1024"
sample_prompts = "/home/csf3o/kohya_ss/outputs/abc1/sample/prompt.txt"
sample_sampler = "euler_a"
save_every_n_epochs = 1
save_model_as = "safetensors"
save_precision = "fp16"
sdpa = false
single_blocks_to_swap = 0
skip_cache_check = false
t5xxl = "/home/csf3o/kohya_ss/models/Flux1-dev/t5xxl_fp16.safetensors"
t5xxl_max_token_length = 512
timestep_sampling = "sigma"
train_batch_size = 1
train_blocks = "all"
train_data_dir = "/home/csf3o/kohya_ss/dataset/images/abc1"
wandb_run_name = "abc1_last"

The commit I am using: 264328d117dc5d17772ec0bdbac2b9f0cf4695f5

If you need any more detail, or if I can help test in any other way, I would be more than happy to do so.

Or maybe I have some settings wrong, in which case I'm sorry for any trouble I may have caused.

Thank you in advance!

kohya-ss (Owner) commented:

Unfortunately I don't know why it doesn't work with ROCm. Could you try the faster_block_swap branch, which implements block swap in a simpler way and might work? In addition, --disable_mmap_load_safetensors may help shorten long pauses.
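For context on the second suggestion: safetensors files are normally memory-mapped, so a load that looks like one long pause can be the OS paging the file in piece by piece; disabling mmap reads the whole file into RAM up front instead. A rough illustration of the two paths using the general safetensors API (not necessarily the exact sd-scripts code path):

from safetensors.torch import load, load_file

path = "/home/csf3o/kohya_ss/models/Flux1-dev/flux_dev.safetensors"

# Default path: the file is memory-mapped and tensors are paged in lazily.
state_dict_mmap = load_file(path, device="cpu")

# Non-mmap path: read the whole file into memory first, then deserialize.
# Slower up front, but avoids repeated page faults while training starts.
with open(path, "rb") as f:
    state_dict_ram = load(f.read())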

csf3o (Author) commented Nov 12, 2024

That does seem to help a little. I lowered blocks to swap to 33, and these were my results:

In one attempt the following happened:

At first usage climbed as before, to 12.3GiB (instead of ~11GB as in the earlier experiment with 36 blocks).

There it stabilized, fluctuating by only a few MB and creeping up maybe a MB every few seconds, in what could look like some sort of memory leak. I assume this is the swapping actually working. But after a few more seconds of this (and no visible progress on an iteration) it SIGABRT'ed with:

Memory access fault by GPU node-1 (Agent handle: 0x55b59badccd0) on address 0x7f6eca614000. Reason: Page not present or supervisor privilege.

To me that looks like some kind of segfault.

This doesn't always happen, however. When I run it again, it stabilizes similarly, but instead of segfaulting it prints:

epoch 1/2
2024-11-12 09:20:49 INFO     epoch is incremented. current_epoch: 0, epoch: 1  train_util.py:715

and then crashes with the following traceback:

Traceback
/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
  return F.linear(input, self.weight, self.bias)
Traceback (most recent call last):                                                              
  File "/home/csf3o/kohya_ss/sd-scripts/flux_train.py", line 891, in <module>                                                                                                 
    train(args)                                                                                                                                                                                 
  File "/home/csf3o/kohya_ss/sd-scripts/flux_train.py", line 680, in train                                                                                                    
    model_pred = flux(                                                                                                                                                                          
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                  
    return self._call_impl(*args, **kwargs)                                                     
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 1036, in forward
    img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 753, in forward
    return torch.utils.checkpoint.checkpoint(
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/csf3o/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 496, in checkpoint
    ret = function(*args, **kwargs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 748, in custom_forward
    outputs = func(*cuda_inputs)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 725, in _forward
    attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
  File "/home/csf3o/kohya_ss/sd-scripts/library/flux_models.py", line 451, in attention
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 1.85 GiB. GPU 0 has a total capacity of 15.98 GiB of which 1.01 GiB is free. Of the allocated memory 14.59 GiB is allocated by PyTo
rch, and 50.20 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See 
documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

This is a little better: ~1GB free instead of ~400MB, but still falls a little short of the 1.85GB it wants.

The behavior is similar with Blockwise Fused Optimizer and AdamW as well.
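To narrow down which phase actually blows up, it might help to log VRAM at a few known points (after the state dict load, after the optimizer is built, just before the first forward). A small helper along these lines; torch.cuda maps onto HIP on ROCm builds, but treat it as a sketch:

import torch

def log_vram(tag: str) -> None:
    # Device-level view (free/total) plus PyTorch's own bookkeeping.
    free, total = torch.cuda.mem_get_info()
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(
        f"[{tag}] free={free / 2**30:.2f}/{total / 2**30:.2f} GiB, "
        f"allocated={allocated / 2**30:.2f} GiB, reserved={reserved / 2**30:.2f} GiB"
    )

# e.g. call log_vram("before first forward") at suspicious points in flux_train.py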

Maybe there is some hope, though, since the block swapping looks like it might actually be somewhat working.
I'm not sure that's really what I'm seeing, because I'm not sure which phase it fails in; it doesn't seem to really get started (my completely uneducated guess is maybe back-propagation?).
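For reference, the basic shape of block swapping is just moving one block's weights off the device once its forward is done and bringing the next block in; the real implementation overlaps those copies with compute. A toy sketch of the pattern (not the sd-scripts implementation, only the idea behind it):

import torch
import torch.nn as nn

device = torch.device("cuda")

# Stand-ins for the FLUX double/single blocks; weights start on the CPU.
blocks = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(8)])
x = torch.randn(1, 4096, device=device)

for block in blocks:
    block.to(device)   # bring this block's weights onto the GPU
    x = block(x)       # run it
    block.to("cpu")    # evict it again to bound peak VRAM
    # (The real implementation prefetches the next block on a side stream
    # with pinned host memory so the transfers overlap the compute.)

print(x.shape)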

Oh, and I forgot to mention (switching branches reminded me): to get adafactor working I needed to apply this patch:

diff --git a/library/train_util.py b/library/train_util.py
index 8b5cf21..d6ff231 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5014,7 +5014,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         assert (
             type(optimizer) == transformers.optimization.Adafactor
         ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
-        initial_lr = float(name.split(":")[1])
+        initial_lr = args.learning_rate
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
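
The failure this works around: get_scheduler_fix parses the initial learning rate out of the scheduler name with name.split(":")[1], which breaks when lr_scheduler = "adafactor" is given without an explicit ":<lr>" suffix (as in the config above), so the patch falls back to args.learning_rate instead. A standalone sketch of the resulting call, with hypothetical values (this is the transformers API, not the sd-scripts wrapper):

import torch
from transformers.optimization import Adafactor, AdafactorSchedule

params = [torch.nn.Parameter(torch.zeros(10))]

# Matches optimizer_args from config.toml above.
optimizer = Adafactor(
    params, lr=1e-5, relative_step=False, scale_parameter=False, warmup_init=False
)

# With relative_step=False the schedule simply reports the fixed lr, so
# passing args.learning_rate as initial_lr (as in the patch) is enough.
scheduler = AdafactorSchedule(optimizer, initial_lr=1e-5)
print(scheduler.get_last_lr())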

Thanks again for the quick answer!
