Skip to content

Commit

Permalink
Merge pull request #538 from spacegoing/fix_rope_finetune_shape
Browse files Browse the repository at this point in the history
[Fix] fix rope temporal patch size
  • Loading branch information
zRzRzRzRzRzRzR authored Nov 23, 2024
2 parents 2fdc59c + 2fb763d commit d82922c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion finetune/train_cogvideox_image_to_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,7 @@ def prepare_rotary_positional_embeddings(
num_frames: int,
vae_scale_factor_spatial: int = 8,
patch_size: int = 2,
patch_size_t: int = 1,
attention_head_dim: int = 64,
device: Optional[torch.device] = None,
base_height: int = 480,
Expand All @@ -922,12 +923,15 @@ def prepare_rotary_positional_embeddings(
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

p_t = patch_size_t
base_num_frames = (num_frames + p_t - 1) // p_t

grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
)

freqs_cos = freqs_cos.to(device=device)
Expand Down Expand Up @@ -1482,6 +1486,7 @@ def collate_fn(examples):
num_frames=num_frames,
vae_scale_factor_spatial=vae_scale_factor_spatial,
patch_size=model_config.patch_size,
patch_size_t=model_config.patch_size_t,
attention_head_dim=model_config.attention_head_dim,
device=accelerator.device,
)
Expand Down
7 changes: 6 additions & 1 deletion finetune/train_cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def prepare_rotary_positional_embeddings(
num_frames: int,
vae_scale_factor_spatial: int = 8,
patch_size: int = 2,
patch_size_t: int = 1,
attention_head_dim: int = 64,
device: Optional[torch.device] = None,
base_height: int = 480,
Expand All @@ -835,12 +836,15 @@ def prepare_rotary_positional_embeddings(
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

p_t = patch_size_t
base_num_frames = (num_frames + p_t - 1) // p_t

grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
)

freqs_cos = freqs_cos.to(device=device)
Expand Down Expand Up @@ -1346,6 +1350,7 @@ def collate_fn(examples):
num_frames=num_frames,
vae_scale_factor_spatial=vae_scale_factor_spatial,
patch_size=model_config.patch_size,
patch_size_t=model_config.patch_size_t,
attention_head_dim=model_config.attention_head_dim,
device=accelerator.device,
)
Expand Down

0 comments on commit d82922c

Please sign in to comment.