diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index e4d1499f44..f4cc8c078f 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -24,6 +24,7 @@ from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
+from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
 from .inplace_abn import InplaceAbn
 from .linear import Linear
diff --git a/timm/layers/drop.py b/timm/layers/drop.py
index 1ab1c8f5ba..289245f5ad 100644
--- a/timm/layers/drop.py
+++ b/timm/layers/drop.py
@@ -18,10 +18,18 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
+
 
 def drop_block_2d(
-        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
-        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+        x,
+        drop_prob: float = 0.1,
+        block_size: int = 7,
+        gamma_scale: float = 1.0,
+        with_noise: bool = False,
+        inplace: bool = False,
+        batchwise: bool = False
+):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
@@ -35,7 +43,7 @@ def drop_block_2d(
             (W - block_size + 1) * (H - block_size + 1))
 
     # Forces the block to be inside the feature map.
-    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
     valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
         ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
     valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
@@ -68,8 +76,13 @@
 
 
 def drop_block_fast_2d(
-        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
-        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
+        x: torch.Tensor,
+        drop_prob: float = 0.1,
+        block_size: int = 7,
+        gamma_scale: float = 1.0,
+        with_noise: bool = False,
+        inplace: bool = False,
+):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
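To illustrate the drop.py change above: `valid_block` marks the seed positions from which a whole block still fits inside the feature map. A minimal sketch with assumed sizes (W = H = 6, clipped block size 3; `ndgrid` is the helper introduced in timm/layers/grid.py just below):

```python
import torch
from timm.layers import ndgrid

W = H = 6
clipped_block_size = 3

# 'ij' order: w_i varies along dim 0, h_i along dim 1, both shaped (W, H)
w_i, h_i = ndgrid(torch.arange(W), torch.arange(H))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
              ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
# one valid seed per possible in-bounds block placement
assert valid_block.sum() == (W - clipped_block_size + 1) * (H - clipped_block_size + 1)
```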
diff --git a/timm/layers/grid.py b/timm/layers/grid.py
new file mode 100644
index 0000000000..f760d761fd
--- /dev/null
+++ b/timm/layers/grid.py
@@ -0,0 +1,49 @@
+from typing import Tuple
+
+import torch
+
+
+def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
+    """Generate N-D grid in dimension order.
+
+    The ndgrid function is like meshgrid except that the order of the first two input arguments is switched.
+
+    That is, the statement
+
+    [X1,X2,X3] = ndgrid(x1,x2,x3)
+
+    produces the same result as
+
+    [X2,X1,X3] = meshgrid(x2,x1,x3)
+
+    This naming is based on MATLAB; the purpose is to avoid the confusion caused by torch's change of
+    torch.meshgrid behaviour from matching ndgrid ('ij') indexing to the numpy meshgrid default of ('xy').
+    """
+    try:
+        return torch.meshgrid(*tensors, indexing='ij')
+    except TypeError:
+        # old PyTorch < 1.10 will follow this path as it does not have the indexing arg,
+        # the old behaviour of meshgrid was 'ij'
+        return torch.meshgrid(*tensors)
+
+
+def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]:
+    """Generate N-D grid in spatial dim order.
+
+    The meshgrid function is similar to ndgrid except that the order of the
+    first two input and output arguments is switched.
+
+    That is, the statement
+
+    [X,Y,Z] = meshgrid(x,y,z)
+
+    produces the same result as
+
+    [Y,X,Z] = ndgrid(y,x,z)
+
+    Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space,
+    while ndgrid is better suited to multidimensional problems that aren't spatially based.
+    """
+    # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support the indexing arg
+    # or the capability of generating grids in xy order before then.
+    return torch.meshgrid(*tensors, indexing='xy')
+
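A quick sanity check of the two helpers above (a sketch; the `meshgrid` call needs PyTorch >= 1.10, as the NOTE in grid.py points out):

```python
import torch
from timm.layers import ndgrid, meshgrid

h, w = torch.arange(2), torch.arange(3)

hh, ww = ndgrid(h, w)      # 'ij': first input varies along dim 0
assert hh.shape == (2, 3)
assert torch.equal(hh, torch.tensor([[0, 0, 0], [1, 1, 1]]))

xx, yy = meshgrid(h, w)    # 'xy': first input varies along dim 1
assert xx.shape == (3, 2)
assert torch.equal(xx, hh.T)  # the two orders are transposes of each other
```

Because `ndgrid` pins the 'ij' behaviour on every PyTorch version, the call sites below can rely on one consistent order instead of per-version meshgrid handling.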
diff --git a/timm/layers/lambda_layer.py b/timm/layers/lambda_layer.py
index e50b43c8c5..9192e266e6 100644
--- a/timm/layers/lambda_layer.py
+++ b/timm/layers/lambda_layer.py
@@ -24,13 +24,14 @@
 from torch import nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
 from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
 
 
 def rel_pos_indices(size):
     size = to_2tuple(size)
-    pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
     rel_pos = pos[:, None, :] - pos[:, :, None]
     rel_pos[0] += size[0] - 1
     rel_pos[1] += size[1] - 1
diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py
index 42a3b2801e..4fcb111e99 100644
--- a/timm/layers/pos_embed_rel.py
+++ b/timm/layers/pos_embed_rel.py
@@ -10,6 +10,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
 from .interpolate import RegularGridInterpolator
 from .mlp import Mlp
 from .weight_init import trunc_normal_
@@ -26,12 +27,7 @@ def gen_relative_position_index(
     # get pair-wise relative position index for each token inside the window
     assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
 
-    coords = torch.stack(
-        torch.meshgrid([
-            torch.arange(q_size[0]),
-            torch.arange(q_size[1])
-        ])
-    ).flatten(1)  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1)  # 2, Wh, Ww
     relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
     relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0
@@ -42,16 +38,16 @@
     # else:
     #     # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
     #     q_coords = torch.stack(
-    #         torch.meshgrid([
+    #         ndgrid(
     #             torch.arange(q_size[0]),
     #             torch.arange(q_size[1])
-    #         ])
+    #         )
     #     ).flatten(1)  # 2, Wh, Ww
     #     k_coords = torch.stack(
-    #         torch.meshgrid([
+    #         ndgrid(
     #             torch.arange(k_size[0]),
     #             torch.arange(k_size[1])
-    #         ])
+    #         )
     #     ).flatten(1)
     #     relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
     #     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
@@ -232,7 +228,7 @@ def _calc(src, dst):
         tx = dst_size[1] // 2.0
         dy = torch.arange(-ty, ty + 0.1, 1.0)
         dx = torch.arange(-tx, tx + 0.1, 1.0)
-        dyx = torch.meshgrid([dy, dx])
+        dyx = ndgrid(dy, dx)
         # print("Target positions = %s" % str(dx))
 
         all_rel_pos_bias = []
@@ -313,7 +309,7 @@ def gen_relative_log_coords(
     # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
-    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
+    relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
     relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
     if mode == 'swin':
         if pretrained_win_size[0] > 0:
diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py
index 4675aba21f..b5f8502f37 100644
--- a/timm/layers/pos_embed_sincos.py
+++ b/timm/layers/pos_embed_sincos.py
@@ -8,6 +8,7 @@
 import torch
 from torch import nn as nn
 
+from .grid import ndgrid
 from .trace_utils import _assert
 
 
@@ -64,10 +65,10 @@ def build_sincos2d_pos_embed(
     if reverse_coord:
         feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
-    grid = torch.stack(torch.meshgrid(
-        [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
-         for s in feat_shape])
-    ).flatten(1).transpose(0, 1)
+    grid = torch.stack(ndgrid([
+        torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
+        for s in feat_shape
+    ])).flatten(1).transpose(0, 1)
 
     pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
     # FIXME add support for unflattened spatial dim?
@@ -137,7 +138,7 @@
         # eva's scheme for resizing rope embeddings (ref shape = pretrain)
         t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
 
-    grid = torch.stack(torch.meshgrid(t), dim=-1)
+    grid = torch.stack(ndgrid(t), dim=-1)
     grid = grid.unsqueeze(-1)
     pos = grid * bands
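Note that the build_sincos2d_pos_embed hunk passes a single list to `ndgrid`; this works because `torch.meshgrid` also accepts one list of tensors, which `ndgrid` simply forwards. A small sketch with an assumed feat_shape:

```python
import torch
from timm.layers import ndgrid

feat_shape = (4, 6)  # assumed (H, W)
grid = torch.stack(ndgrid([torch.arange(s, dtype=torch.int64).to(torch.float32) for s in feat_shape]))
assert grid.shape == (2, 4, 6)  # one (H, W) coordinate plane per spatial dim
flat = grid.flatten(1).transpose(0, 1)  # (H*W, 2), the layout build_sincos2d_pos_embed uses
assert flat.shape == (24, 2)
```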
diff --git a/timm/models/beit.py b/timm/models/beit.py
index 663dcc4bd4..0167099ce7 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -48,7 +48,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
-from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table
+from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
 
 from ._builder import build_model_with_cfg
@@ -63,9 +63,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     # cls to token & token 2 cls & cls to cls
     # get pair-wise relative position index for each token inside the window
     window_area = window_size[0] * window_size[1]
-    coords = torch.stack(torch.meshgrid(
-        [torch.arange(window_size[0]),
-         torch.arange(window_size[1])]))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index fe869d6654..727fcac31b 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -18,7 +18,7 @@
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
+from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -63,7 +63,7 @@ def __init__(
         self.proj = nn.Linear(self.val_attn_dim, dim)
 
         resolution = to_2tuple(resolution)
-        pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py
index 357b258dec..8f76bed369 100644
--- a/timm/models/efficientformer_v2.py
+++ b/timm/models/efficientformer_v2.py
@@ -23,7 +23,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
-from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
+from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -129,7 +129,7 @@ def __init__(
         self.act = act_layer()
         self.proj = ConvNorm(self.dh, dim, 1)
 
-        pos = torch.stack(torch.meshgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))
@@ -231,12 +231,11 @@ def __init__(
         self.proj = ConvNorm(self.dh, self.out_dim, 1)
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))
-        k_pos = torch.stack(torch.meshgrid(torch.arange(
-            self.resolution[0]),
-            torch.arange(self.resolution[1]))).flatten(1)
-        q_pos = torch.stack(torch.meshgrid(
+        k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        q_pos = torch.stack(ndgrid(
             torch.arange(0, self.resolution[0], step=2),
-            torch.arange(0, self.resolution[1], step=2))).flatten(1)
+            torch.arange(0, self.resolution[1], step=2)
+        )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
diff --git a/timm/models/levit.py b/timm/models/levit.py
index 71ce4a364f..ca0708bd59 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -31,7 +31,7 @@
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
-from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
+from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -194,7 +194,7 @@ def __init__(
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
@@ -290,10 +290,11 @@ def __init__(
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        k_pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
-        q_pos = torch.stack(torch.meshgrid(
+        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        q_pos = torch.stack(ndgrid(
             torch.arange(0, resolution[0], step=stride),
-            torch.arange(0, resolution[1], step=stride))).flatten(1)
+            torch.arange(0, resolution[1], step=stride)
+        )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
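The LeViT and EfficientFormerV2 downsample hunks above share one pattern: dense k positions, strided q positions, and an absolute-offset index into the attention-bias table. A sketch with assumed resolution and stride:

```python
import torch
from timm.layers import ndgrid

resolution, stride = (4, 4), 2
k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)  # 2, 16
q_pos = torch.stack(ndgrid(
    torch.arange(0, resolution[0], step=stride),
    torch.arange(0, resolution[1], step=stride),
)).flatten(1)  # 2, 4
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]  # fold (dy, dx) into one table index
assert rel_pos.shape == (4, 16)  # (q tokens, k tokens)
```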
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 471b94e84e..bb3f9508b9 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -24,7 +24,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
-    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed
+    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
@@ -78,7 +78,7 @@ def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
 
 def get_relative_position_index(win_h: int, win_w: int):
     # get pair-wise relative position index for each token inside the window
-    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 0815856b27..b152b54470 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -22,7 +22,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
-    resample_patch_embed
+    resample_patch_embed, ndgrid
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -107,9 +107,8 @@ def __init__(
         # get relative_coords_table
         relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
         relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
-        relative_coords_table = torch.stack(torch.meshgrid([
-            relative_coords_h,
-            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
+        relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
+        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
         if pretrained_window_size[0] > 0:
             relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
             relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
@@ -125,7 +124,7 @@ def __init__(
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords = torch.stack(ndgrid(coords_h, coords_w))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
biases.""" device = self.logit_scale.device - coordinates = torch.stack(torch.meshgrid([ + coordinates = torch.stack(ndgrid( torch.arange(self.window_size[0], device=device), - torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) + torch.arange(self.window_size[1], device=device) + ), dim=0).flatten(1) relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(