
Commit

Fix meshgrid deprecation warnings and backward compat with explicit 'ndgrid' and 'meshgrid' fn w/o indexing arg
rwightman committed Jan 27, 2024
1 parent fa247fd commit 88889de
Showing 13 changed files with 110 additions and 51 deletions.
1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -24,6 +24,7 @@
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .grid import ndgrid, meshgrid
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
23 changes: 18 additions & 5 deletions timm/layers/drop.py
@@ -18,10 +18,18 @@
import torch.nn as nn
import torch.nn.functional as F

from .grid import ndgrid


def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
x,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False
):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
@@ -35,7 +43,7 @@ def drop_block_2d(
(W - block_size + 1) * (H - block_size + 1))

# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
@@ -68,8 +76,13 @@ def drop_block_2d(


def drop_block_fast_2d(
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
x: torch.Tensor,
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
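As a rough usage sketch (not part of the commit), the refactored drop_block_2d can be exercised as below; the tensor shape and parameter values are illustrative only.

import torch
from timm.layers.drop import drop_block_2d

x = torch.randn(2, 8, 28, 28)  # illustrative NCHW feature map
y = drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0)
assert y.shape == x.shape  # DropBlock zeroes regions but preserves the tensor shape
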
49 changes: 49 additions & 0 deletions timm/layers/grid.py
@@ -0,0 +1,49 @@
from typing import Tuple

import torch


def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
"""generate N-D grid in dimension order.
The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
That is, the statement
[X1,X2,X3] = ndgrid(x1,x2,x3)
produces the same result as
[X2,X1,X3] = meshgrid(x2,x1,x3)
This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
"""
try:
return torch.meshgrid(*tensors, indexing='ij')
except TypeError:
# old PyTorch < 1.10 will follow this path as it does not have indexing arg,
# the old behaviour of meshgrid was 'ij'
return torch.meshgrid(*tensors)


def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]:
"""generate N-D grid in spatial dim order.
The meshgrid function is similar to ndgrid except that the order of the
first two input and output arguments is switched.
That is, the statement
[X,Y,Z] = meshgrid(x,y,z)
produces the same result as
[Y,X,Z] = ndgrid(y,x,z)
Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space,
while ndgrid is better suited to multidimensional problems that aren't spatially based.
"""

# NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support indexing arg or have
# capability of generating grid in xy order before then.
return torch.meshgrid(*tensors, indexing='xy')

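The difference between the two new helpers is easiest to see on a small sketch (assuming PyTorch >= 1.10 so the indexing argument exists; the values below are illustrative):

import torch
from timm.layers.grid import ndgrid, meshgrid

h, w = torch.arange(2), torch.arange(3)
gh, gw = ndgrid(h, w)      # 'ij' order: both grids have shape (len(h), len(w)) == (2, 3)
gh2, gw2 = meshgrid(h, w)  # 'xy' order: both grids have shape (len(w), len(h)) == (3, 2)
# ndgrid reproduces the pre-1.10 default of torch.meshgrid, so existing call sites
# keep their behaviour while silencing the deprecation warning.
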
3 changes: 2 additions & 1 deletion timm/layers/lambda_layer.py
@@ -24,13 +24,14 @@
from torch import nn
import torch.nn.functional as F

from .grid import ndgrid
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_


def rel_pos_indices(size):
size = to_2tuple(size)
pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
rel_pos = pos[:, None, :] - pos[:, :, None]
rel_pos[0] += size[0] - 1
rel_pos[1] += size[1] - 1
20 changes: 8 additions & 12 deletions timm/layers/pos_embed_rel.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .grid import ndgrid
from .interpolate import RegularGridInterpolator
from .mlp import Mlp
from .weight_init import trunc_normal_
@@ -26,12 +27,7 @@ def gen_relative_position_index(
# get pair-wise relative position index for each token inside the window
assert k_size is None, 'Different q & k sizes not currently supported' # FIXME

coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
@@ -42,16 +38,16 @@ def gen_relative_position_index(
# else:
# # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
# q_coords = torch.stack(
# torch.meshgrid([
# ndgrid(
# torch.arange(q_size[0]),
# torch.arange(q_size[1])
# ])
# )
# ).flatten(1) # 2, Wh, Ww
# k_coords = torch.stack(
# torch.meshgrid([
# ndgrid(
# torch.arange(k_size[0]),
# torch.arange(k_size[1])
# ])
# )
# ).flatten(1)
# relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
# relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
@@ -232,7 +228,7 @@ def _calc(src, dst):
tx = dst_size[1] // 2.0
dy = torch.arange(-ty, ty + 0.1, 1.0)
dx = torch.arange(-tx, tx + 0.1, 1.0)
dyx = torch.meshgrid([dy, dx])
dyx = ndgrid(dy, dx)
# print("Target positions = %s" % str(dx))

all_rel_pos_bias = []
@@ -313,7 +309,7 @@ def gen_relative_log_coords(
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
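As an illustration of the ndgrid-based table construction in gen_relative_log_coords above, the sketch below builds the raw relative coordinate table for a hypothetical 3x3 window (shapes only; the 'swin'/'cr' normalization that follows in the file is omitted):

import torch
from timm.layers import ndgrid

win_size = (3, 3)  # hypothetical window size
rel_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
rel_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
table = torch.stack(ndgrid(rel_h, rel_w)).permute(1, 2, 0).contiguous()
# table.shape == (2*Wh-1, 2*Ww-1, 2) == (5, 5, 2): one (dy, dx) pair per possible offset
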
11 changes: 6 additions & 5 deletions timm/layers/pos_embed_sincos.py
@@ -8,6 +8,7 @@
import torch
from torch import nn as nn

from .grid import ndgrid
from .trace_utils import _assert


@@ -64,10 +65,10 @@ def build_sincos2d_pos_embed(

if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
for s in feat_shape])
).flatten(1).transpose(0, 1)
grid = torch.stack(ndgrid([
torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
for s in feat_shape
])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?

@@ -137,7 +138,7 @@ def build_fourier_pos_embed(
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

grid = torch.stack(torch.meshgrid(t), dim=-1)
grid = torch.stack(ndgrid(t), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands

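The grid stage of build_sincos2d_pos_embed reduces to the following sketch (feature shape chosen arbitrarily); each row of the result is a (y, x) position in 'ij' order that is later multiplied by the frequency bands:

import torch
from timm.layers import ndgrid

feat_shape = (7, 7)  # arbitrary example shape
grid = torch.stack(ndgrid([
    torch.arange(s, dtype=torch.float32) for s in feat_shape
])).flatten(1).transpose(0, 1)
# grid.shape == (H * W, 2) == (49, 2)
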
6 changes: 2 additions & 4 deletions timm/models/beit.py
@@ -48,7 +48,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid


from ._builder import build_model_with_cfg
@@ -63,9 +63,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
window_area = window_size[0] * window_size[1]
coords = torch.stack(torch.meshgrid(
[torch.arange(window_size[0]),
torch.arange(window_size[1])])) # 2, Wh, Ww
coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1]))) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
4 changes: 2 additions & 2 deletions timm/models/efficientformer.py
@@ -18,7 +18,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -63,7 +63,7 @@ def __init__(
self.proj = nn.Linear(self.val_attn_dim, dim)

resolution = to_2tuple(resolution)
pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
13 changes: 6 additions & 7 deletions timm/models/efficientformer_v2.py
@@ -23,7 +23,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -129,7 +129,7 @@ def __init__(
self.act = act_layer()
self.proj = ConvNorm(self.dh, dim, 1)

pos = torch.stack(torch.meshgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))
@@ -231,12 +231,11 @@ def __init__(
self.proj = ConvNorm(self.dh, self.out_dim, 1)

self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))
k_pos = torch.stack(torch.meshgrid(torch.arange(
self.resolution[0]),
torch.arange(self.resolution[1]))).flatten(1)
q_pos = torch.stack(torch.meshgrid(
k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
q_pos = torch.stack(ndgrid(
torch.arange(0, self.resolution[0], step=2),
torch.arange(0, self.resolution[1], step=2))).flatten(1)
torch.arange(0, self.resolution[1], step=2)
)).flatten(1)
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
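The strided query vs. full key position pattern used by the EfficientFormer-V2 and LeViT downsample attention blocks can be sketched as follows (resolution and stride are illustrative):

import torch
from timm.layers import ndgrid

resolution, stride = (4, 4), 2  # illustrative values
k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
q_pos = torch.stack(ndgrid(
    torch.arange(0, resolution[0], step=stride),
    torch.arange(0, resolution[1], step=stride),
)).flatten(1)
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
bias_idx = rel_pos[0] * resolution[1] + rel_pos[1]
# bias_idx.shape == (4, 16): one bias-table index per (downsampled query, key) pair
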
11 changes: 6 additions & 5 deletions timm/models/levit.py
@@ -31,7 +31,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -194,7 +194,7 @@ def __init__(
]))

self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
@@ -290,10 +290,11 @@ def __init__(
]))

self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
k_pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
q_pos = torch.stack(torch.meshgrid(
k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
q_pos = torch.stack(ndgrid(
torch.arange(0, resolution[0], step=stride),
torch.arange(0, resolution[1], step=stride))).flatten(1)
torch.arange(0, resolution[1], step=stride)
)).flatten(1)
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
4 changes: 2 additions & 2 deletions timm/models/swin_transformer.py
@@ -24,7 +24,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
_assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed
_assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
@@ -78,7 +78,7 @@ def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):

def get_relative_position_index(win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w))) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
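A small end-to-end sketch of get_relative_position_index for a tiny 2x2 window; the shift-and-scale steps after the permute are not shown in the hunk above and follow the standard Swin formulation:

import torch
from timm.layers import ndgrid

win_h = win_w = 2  # tiny illustrative window
coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                               # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()         # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1   # shift offsets to start at 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
index = relative_coords.sum(-1)         # (Wh*Ww, Wh*Ww) indices into the bias table
# index is identical whether built with the old torch.meshgrid or the new ndgrid
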
9 changes: 4 additions & 5 deletions timm/models/swin_transformer_v2.py
@@ -22,7 +22,7 @@

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
resample_patch_embed
resample_patch_embed, ndgrid
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -107,9 +107,8 @@ def __init__(
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([
relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
@@ -125,7 +124,7 @@ def __init__(
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords = torch.stack(ndgrid(coords_h, coords_w)) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
7 changes: 4 additions & 3 deletions timm/models/swin_transformer_v2_cr.py
@@ -37,7 +37,7 @@
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
@@ -141,9 +141,10 @@ def __init__(
def _make_pair_wise_relative_positions(self) -> None:
"""Method initializes the pair-wise relative positions to compute the positional biases."""
device = self.logit_scale.device
coordinates = torch.stack(torch.meshgrid([
coordinates = torch.stack(ndgrid(
torch.arange(self.window_size[0], device=device),
torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
torch.arange(self.window_size[1], device=device)
), dim=0).flatten(1)
relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(
