Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove call to F.pad, improved calculation of memory_count #10620

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

bm-synth
Copy link

@bm-synth bm-synth commented Jan 21, 2025

  • remove one call to symmetric padding in F.pad when running with non-replicate pad mode, and instead let padding be done by Conv3d for a more efficient execution;
  • computation of memory_count doesn't extend dimensions to allow torch.compile to do a better optimisation (?) by @ic-synth

cc: @jamesbriggs-synth

@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad and better memory count Inplace sums, remove call to F.pad, improved calculation of memory Jan 21, 2025
@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad, improved calculation of memory Inplace sums, remove call to F.pad, improved calculation of memory_count Jan 21, 2025
@bm-synth bm-synth marked this pull request as ready for review January 21, 2025 12:01
@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad, improved calculation of memory_count in-place sums, remove call to F.pad, improved calculation of memory_count Jan 21, 2025
@hlky
Copy link
Collaborator

hlky commented Jan 22, 2025

Hi @bm-synth. Thanks for your contribution. Can you share some figures on the memory and performance improvements?

@brunomaga
Copy link

brunomaga commented Jan 24, 2025

Hi @hlky.

Running the following test_autoencoder.py

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d

torch.manual_seed(42)

def train(model: nn.Module, video_input: torch.Tensor):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    start_train = time.time()
    for iteration in range(100):  # Simulate 100 training iterations
        optimizer.zero_grad()
        output = model(video_input)[0]
        loss = F.mse_loss(output, output+iteration) # sum iteration to fake different grads per iteration
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
    train_time = time.time() - start_train
    print("train_time", train_time, "secs")
    return output.to("cpu")


def eval(model: nn.Module, video_input: torch.Tensor):
    model.eval()
    start_train = time.time()
    with torch.no_grad():
        for _ in range(300):  # Simulate 300 inference iterations
            model(video_input)
            torch.cuda.synchronize()
    eval_time = time.time() - start_train
    print("eval_time", eval_time, "secs")

calling with that input shape [1, 128, 8, 544, 960], on the main branch, gives:

$ PYTHONPATH=./diffusers_main/src/ python test_autoencoder.py
input size:  0.498046875 GBs
eval_time 33.06385564804077 secs
train_time 34.33984375 secs
Max memory 22.18018913269043 GBs

calling this PR branch gives:

$ PYTHONPATH=./diffusers_PR/src/ python test_autoencoder.py
input size:  0.498046875 GBs
eval_time 31.588099241256714 secs
train_time 34.1251916885376 secs
Max memory 22.17398452758789 GBs

on the shape (1, 3, 300, 544, 960), main branch:

$ PYTHONPATH=./diffusers_main/src/ python test_autoencoder.py
input size:  0.43773651123046875 GBs
eval_time 17.759469032287598 secs
train_time 96.50320744514465 secs
Max memory 16.353439331054688 GBs

and this PR:

$ PYTHONPATH=./diffusers_PR/src/ python test_autoencoder.py
input size:  0.43773651123046875 GBs
eval_time 16.8880774974823 secs
train_time 96.04004764556885 secs
Max memory 16.34803009033203 GBs

I'll try to test more dimensions.

@bm-synth bm-synth changed the title in-place sums, remove call to F.pad, improved calculation of memory_count remove call to F.pad, improved calculation of memory_count Jan 25, 2025
@hlky
Copy link
Collaborator

hlky commented Jan 27, 2025

@bm-synth Great, thanks. Would it also be possible to verify numerical accuracy between the two versions? For a change like this we would expect between 0 to 1e-6 difference.

@brunomaga
Copy link

brunomaga commented Jan 27, 2025

@hlky I updated the code above to fix a seed (torch.manual_seed(42)) and save the tensor with the model output after 100 training iterations. Then I ran this to compare both output_*.pt files:

if __name__=='__main__':
    output_main: torch.Tensor = torch.load("output_main.pt")
    output_PR: torch.Tensor = torch.load("output_PR.pt")
    print("mean:", output_main.mean().item(), "vs", output_PR.mean().item())
    print("std:", output_main.std().item(), "vs", output_PR.std().item())
    print("max abs diff:", (output_PR-output_main).diff().abs().max().item())
    assert torch.allclose(output_main, output_PR)

output:

mean: -8.058547973632812e-05 vs -8.058547973632812e-05
std: 0.578125 vs 0.578125
max abs diff: 0.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants