diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py index edbfc0fd..8edd7a80 100644 --- a/finetune/train_cogvideox_image_to_video_lora.py +++ b/finetune/train_cogvideox_image_to_video_lora.py @@ -912,6 +912,7 @@ def prepare_rotary_positional_embeddings( num_frames: int, vae_scale_factor_spatial: int = 8, patch_size: int = 2, + patch_size_t: int = 1, attention_head_dim: int = 64, device: Optional[torch.device] = None, base_height: int = 480, @@ -922,12 +923,15 @@ def prepare_rotary_positional_embeddings( base_size_width = base_width // (vae_scale_factor_spatial * patch_size) base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + p_t = patch_size_t + base_num_frames = (num_frames + p_t - 1) // p_t + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) freqs_cos, freqs_sin = get_3d_rotary_pos_embed( embed_dim=attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -1482,6 +1486,7 @@ def collate_fn(examples): num_frames=num_frames, vae_scale_factor_spatial=vae_scale_factor_spatial, patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t, attention_head_dim=model_config.attention_head_dim, device=accelerator.device, ) diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py index 5d20908e..70c39f07 100644 --- a/finetune/train_cogvideox_lora.py +++ b/finetune/train_cogvideox_lora.py @@ -825,6 +825,7 @@ def prepare_rotary_positional_embeddings( num_frames: int, vae_scale_factor_spatial: int = 8, patch_size: int = 2, + patch_size_t: int = 1, attention_head_dim: int = 64, device: Optional[torch.device] = None, base_height: int = 480, @@ -835,12 +836,15 @@ def prepare_rotary_positional_embeddings( base_size_width = base_width // (vae_scale_factor_spatial * patch_size) base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + p_t = patch_size_t + base_num_frames = (num_frames + p_t - 1) // p_t + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) freqs_cos, freqs_sin = get_3d_rotary_pos_embed( embed_dim=attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -1346,6 +1350,7 @@ def collate_fn(examples): num_frames=num_frames, vae_scale_factor_spatial=vae_scale_factor_spatial, patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t, attention_head_dim=model_config.attention_head_dim, device=accelerator.device, )