FSDP support #1775

Open
ljleb opened this issue Nov 11, 2024 · 1 comment
ljleb commented Nov 11, 2024

I tried using an FSDP config like this for accelerate (taken from #1480 (comment)) to finetune SDXL. The UI is bmaltais/kohya_ss.

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_min_num_params: 100000000
machine_rank: 0
main_training_function: main
mixed_precision: fp32
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
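
For context, my understanding is that accelerate translates this config into roughly the following raw PyTorch FSDP wrapping. This is a sketch of the intent, not accelerate's actual code, and wrap_like_config is just an illustrative name:

import functools

import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_like_config(model: nn.Module) -> FSDP:
    """Wrap `model` the way the accelerate config above asks for.

    Must run under torchrun/accelerate with a process group already initialized.
    """
    return FSDP(
        model,
        # SIZE_BASED_WRAP + fsdp_min_num_params: any submodule with more than
        # 100M parameters becomes its own FSDP unit.
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=100_000_000
        ),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # fsdp_sharding_strategy: 1
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
        use_orig_params=True,
        sync_module_states=True,
    )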

But it gives me this error:

[...]
number of trainable parameters: 3385184004
prepare optimizer, data loader etc.
running training / 学習開始
  num examples / サンプル数: 100
  num batches per epoch / 1epochのバッチ数: 100
  num epochs / epoch数: 1
  batch size per device / バッチサイズ: 1
  gradient accumulation steps / 勾配を合計するステップ数 = 1
  total optimization steps / 学習ステップ数: 100
steps:   0%|                                                                                                 | 0/100 [00:00<?, ?it/s]
epoch 1/1
2024-11-11 08:08:31 INFO     epoch is incremented. current_epoch: 0, epoch: 1                                       train_util.py:703
Traceback (most recent call last):
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 822, in <module>
    train(args)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 614, in train
    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1104, in forward
    h = call_module(module, h, emb, context)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1097, in call_module
    x = layer(x)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 750, in forward
    hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 669, in forward
    output = torch.utils.checkpoint.checkpoint(
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 665, in custom_forward
    return func(*inputs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 652, in forward_body
    hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 445, in forward
    return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 524, in forward_memory_efficient_mem_eff
    k = self.to_k(context)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (576x1280 and 2048x1280)
steps:   0%|                                                                                                 | 0/100 [00:00<?, ?it/s]
[2024-11-11 08:08:34,959] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 232842) of binary: /home/ljleb/src/kohya_ss/venv/bin/python

I have 2 P40s, which means only fp32 is practical, and 24 GB on a single card does not appear to be enough memory. I would like to distribute the parameters of a single SDXL model across multiple GPUs to reduce the per-card memory usage of traditional finetuning, not any type of PEFT.

I am not very familiar with using multiple GPUs to train models. Can FSDP put the first half of SDXL on cuda:0 and the other half on cuda:1 with the existing code?
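
To illustrate what I mean in plain PyTorch terms (toy modules standing in for the two halves, nothing SDXL-specific):

# Toy illustration of the placement I have in mind (naive model parallelism,
# not FSDP's sharding): the first half of the network lives on cuda:0, the
# second half on cuda:1, and activations are handed off between cards by hand.
import torch
import torch.nn as nn

first_half = nn.Sequential(nn.Linear(320, 1280), nn.SiLU()).to("cuda:0")
second_half = nn.Sequential(nn.Linear(1280, 320)).to("cuda:1")

x = torch.randn(1, 320, device="cuda:0")
h = first_half(x)
out = second_half(h.to("cuda:1"))  # move activations to the second card
print(out.device)  # cuda:1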


ljleb commented Nov 11, 2024

I figured out that I can put the text encoders on a different device manually. It's a very ad-hoc solution, but it works. It would be great if there were a way to partition the load across a larger number of devices.
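
The placement looks roughly like this, with stand-in modules rather than sd-scripts' actual classes (the hidden sizes 768/1280 and the 2048-dim concatenated context follow the usual SDXL convention):

# Minimal sketch of the ad-hoc placement: both text encoders on cuda:1, the
# UNet on cuda:0, with the prompt embeddings moved across devices by hand.
import torch
import torch.nn as nn

unet_device = torch.device("cuda:0")
te_device = torch.device("cuda:1")

# Stand-ins for SDXL's two text encoders and for a UNet cross-attention
# projection that consumes the concatenated 2048-dim context.
text_encoder1 = nn.Linear(768, 768).to(te_device)
text_encoder2 = nn.Linear(1280, 1280).to(te_device)
unet_to_k = nn.Linear(2048, 1280).to(unet_device)

hidden1 = torch.randn(1, 77, 768, device=te_device)
hidden2 = torch.randn(1, 77, 1280, device=te_device)

# Encode prompts on cuda:1, then move the concatenated embedding to the
# UNet's device before the denoising forward pass.
with torch.no_grad():
    context = torch.cat([text_encoder1(hidden1), text_encoder2(hidden2)], dim=-1)
k = unet_to_k(context.to(unet_device))
print(k.shape, k.device)  # torch.Size([1, 77, 1280]) cuda:0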
