From e7dafbe9b5dad3bdc3f4b26bff1f666ff7f20d2a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:37:28 +0800 Subject: [PATCH] Improve `MobileViT` --- kimm/models/mobilevit.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 5c7b970..91dc807 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -41,6 +41,7 @@ def unfold(inputs, patch_size): + # TODO: improve performance x = inputs h, w, c = x.shape[-3], x.shape[-2], x.shape[-1] new_h, new_w = ( @@ -50,21 +51,17 @@ def unfold(inputs, patch_size): num_patches_h = new_h // patch_size num_patches_w = new_w // patch_size num_patches = num_patches_h * num_patches_w - patch_area = int(patch_size * patch_size) - # [B, H, W, C] -> [B, C, H, W] - x = ops.transpose(x, [0, 3, 1, 2]) - # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w] - x = ops.reshape(x, [-1, patch_size, num_patches_w, patch_size]) - x = ops.swapaxes(x, 1, 2) - # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] - # where P = p_h * p_w and N = n_h * n_w - x = ops.reshape(x, [-1, c, num_patches, patch_area]) - x = ops.swapaxes(x, 1, 3) + # [B, H, W, C] -> [B * P, N, C] + x = ops.reshape( + x, [-1, num_patches_h, patch_size, num_patches_w, patch_size, c] + ) + x = ops.transpose(x, [0, 2, 4, 1, 3, 5]) x = ops.reshape(x, [-1, num_patches, c]) return x def fold(inputs, h, w, c, patch_size): + # TODO: improve performance x = inputs new_h, new_w = ( math.ceil(h / patch_size) * patch_size, @@ -72,18 +69,14 @@ def fold(inputs, h, w, c, patch_size): ) num_patches_h = new_h // patch_size num_patches_w = new_w // patch_size - num_patches = num_patches_h * num_patches_w - patch_area = int(patch_size * patch_size) - # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w] - x = ops.reshape(x, [-1, patch_area, num_patches, c]) - x = ops.swapaxes(x, 1, 3) - x = ops.reshape(x, [-1, num_patches_w, patch_size, patch_size]) - # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] - x = ops.swapaxes(x, 1, 2) + # [B * P, N, C] -> [B, P, N, C] -> [B, H, W, C] + x = ops.reshape( + x, [-1, patch_size, patch_size, num_patches_h, num_patches_w, c] + ) + x = ops.transpose(x, [0, 3, 1, 4, 2, 5]) x = ops.reshape( - x, [-1, c, num_patches_h * patch_size, num_patches_w * patch_size] + x, [-1, num_patches_h * patch_size, num_patches_w * patch_size, c] ) - x = ops.transpose(x, [0, 2, 3, 1]) return x @@ -126,7 +119,6 @@ def apply_mobilevit_block( )(x) # Unfold (feature map -> patches) - # TODO: improves performance h, w, c = x.shape[-3], x.shape[-2], x.shape[-1] x = unfold(x, patch_size)