Skip to content

Commit

Permalink
Improve MobileViT
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jan 12, 2024
1 parent 24e601a commit e7dafbe
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions kimm/models/mobilevit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@


def unfold(inputs, patch_size):
# TODO: improve performance
x = inputs
h, w, c = x.shape[-3], x.shape[-2], x.shape[-1]
new_h, new_w = (
Expand All @@ -50,40 +51,32 @@ def unfold(inputs, patch_size):
num_patches_h = new_h // patch_size
num_patches_w = new_w // patch_size
num_patches = num_patches_h * num_patches_w
patch_area = int(patch_size * patch_size)
# [B, H, W, C] -> [B, C, H, W]
x = ops.transpose(x, [0, 3, 1, 2])
# [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
x = ops.reshape(x, [-1, patch_size, num_patches_w, patch_size])
x = ops.swapaxes(x, 1, 2)
# [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C]
# where P = p_h * p_w and N = n_h * n_w
x = ops.reshape(x, [-1, c, num_patches, patch_area])
x = ops.swapaxes(x, 1, 3)
# [B, H, W, C] -> [B * P, N, C]
x = ops.reshape(
x, [-1, num_patches_h, patch_size, num_patches_w, patch_size, c]
)
x = ops.transpose(x, [0, 2, 4, 1, 3, 5])
x = ops.reshape(x, [-1, num_patches, c])
return x


def fold(inputs, h, w, c, patch_size):
# TODO: improve performance
x = inputs
new_h, new_w = (
math.ceil(h / patch_size) * patch_size,
math.ceil(w / patch_size) * patch_size,
)
num_patches_h = new_h // patch_size
num_patches_w = new_w // patch_size
num_patches = num_patches_h * num_patches_w
patch_area = int(patch_size * patch_size)
# [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
x = ops.reshape(x, [-1, patch_area, num_patches, c])
x = ops.swapaxes(x, 1, 3)
x = ops.reshape(x, [-1, num_patches_w, patch_size, patch_size])
# [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
x = ops.swapaxes(x, 1, 2)
# [B * P, N, C] -> [B, P, N, C] -> [B, H, W, C]
x = ops.reshape(
x, [-1, patch_size, patch_size, num_patches_h, num_patches_w, c]
)
x = ops.transpose(x, [0, 3, 1, 4, 2, 5])
x = ops.reshape(
x, [-1, c, num_patches_h * patch_size, num_patches_w * patch_size]
x, [-1, num_patches_h * patch_size, num_patches_w * patch_size, c]
)
x = ops.transpose(x, [0, 2, 3, 1])
return x


Expand Down Expand Up @@ -126,7 +119,6 @@ def apply_mobilevit_block(
)(x)

# Unfold (feature map -> patches)
# TODO: improves performance
h, w, c = x.shape[-3], x.shape[-2], x.shape[-1]
x = unfold(x, patch_size)

Expand Down

0 comments on commit e7dafbe

Please sign in to comment.