-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlap_pyramid.py
32 lines (27 loc) · 1.13 KB
/
lap_pyramid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.nn as nn
import torch.nn.functional as F
import torch
# Select CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Lap_Pyramid_Bicubic(nn.Module):
    """Laplacian pyramid decomposition/reconstruction with bicubic resampling.

    Builds a pyramid of ``num_high`` high-frequency residual levels plus one
    low-frequency base image, and reconstructs the original image from such
    a pyramid. Resampling uses ``F.interpolate`` in ``'bicubic'`` mode with
    ``align_corners=True`` throughout, so decompose -> reconstruct is an
    identity up to floating-point error.
    """

    def __init__(self, num_high=3):
        """
        Args:
            num_high: number of high-frequency (residual) levels to extract.
        """
        super(Lap_Pyramid_Bicubic, self).__init__()
        self.interpolate_mode = 'bicubic'
        self.num_high = num_high

    def pyramid_decom(self, img):
        """Decompose ``img`` into a Laplacian pyramid.

        Args:
            img: input tensor of shape (N, C, H, W).

        Returns:
            A list of ``num_high + 1`` tensors: high-frequency residuals
            ordered finest to coarsest, followed by the low-frequency base.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            h, w = current.shape[2], current.shape[3]
            # Downsample by 2, then upsample back; the difference is the
            # detail lost at this scale (the Laplacian band).
            down = F.interpolate(current, size=(h // 2, w // 2),
                                 mode=self.interpolate_mode, align_corners=True)
            up = F.interpolate(down, size=(h, w),
                               mode=self.interpolate_mode, align_corners=True)
            pyr.append(current - up)
            current = down
        pyr.append(current)  # low-frequency base image
        return pyr

    def pyramid_recons(self, pyr):
        """Reconstruct the image from a Laplacian pyramid.

        Inverse of :meth:`pyramid_decom`.

        Args:
            pyr: list of tensors as returned by ``pyramid_decom``.

        Returns:
            Reconstructed image tensor with the shape of ``pyr[0]``.
        """
        image = pyr[-1]
        for level in reversed(pyr[:-1]):
            # Upsample the running reconstruction to this level's resolution
            # and add back the stored residual.
            image = F.interpolate(image, size=(level.shape[2], level.shape[3]),
                                  mode=self.interpolate_mode, align_corners=True) + level
        return image