Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove call to F.pad, improve calculation of memory_count #10620

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
9 changes: 4 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
"""

def forward(self, input: torch.Tensor) -> torch.Tensor:
memory_count = (
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
)
memory_count = torch.prod(torch.tensor(input.shape)) * 2 / 1024**3

# Set to 2GB, suitable for CuDNN
if memory_count > 2:
Expand Down Expand Up @@ -105,6 +103,7 @@ def __init__(
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)

self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
Expand All @@ -117,6 +116,8 @@ def __init__(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
padding_mode="zeros",
)

def fake_context_parallel_forward(
Expand All @@ -137,9 +138,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

output = self.conv(inputs)
return output, conv_cache
Expand Down