-
Notifications
You must be signed in to change notification settings - Fork 0
/
pixel_unshuffle.py
33 lines (28 loc) · 1.1 KB
/
pixel_unshuffle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F
def pixel_unshuffle(input, downscale_factor):
'''
input: batchSize * c * k*w * k*h
kdownscale_factor: k
batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
'''
c = input.shape[1]
kernel = torch.zeros(size=[downscale_factor * downscale_factor * c,
1, downscale_factor, downscale_factor],
device=input.device)
for y in range(downscale_factor):
for x in range(downscale_factor):
kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1
return F.conv2d(input, kernel, stride=downscale_factor, groups=c)
class PixelUnshuffle(nn.Module):
    """``nn.Module`` wrapper around ``pixel_unshuffle`` with a fixed factor.

    Moves each ``k x k`` spatial block into the channel dimension, where
    ``k = downscale_factor``::

        (N, C, H*k, W*k) -> (N, C*k*k, H, W)

    Args:
        downscale_factor: integer block size ``k``.
    """

    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        """Return ``pixel_unshuffle(input, self.downscale_factor)``."""
        return pixel_unshuffle(input, self.downscale_factor)

    def extra_repr(self):
        # Shown by ``repr(module)`` / ``print(model)``, like built-in layers.
        return f"downscale_factor={self.downscale_factor}"