-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDWT_IDWT_Functions.py
executable file
·168 lines (153 loc) · 10.2 KB
/
DWT_IDWT_Functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
"""
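Custom PyTorch autograd functions implementing the DWT and IDWT for 1-D, 2-D and 3-D tensors; boundary extension is not taken into account.
Exact reconstruction is only obtained when the image has an even number of rows and columns and the low-frequency part of the reconstruction filter bank has length 2; otherwise there are errors at the boundaries.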
"""
import torch
from torch.autograd import Function
class DWTFunction_1D(Function):
    """Autograd function applying a low-pass and a high-pass matrix to the last axis.

    The transform matrices are treated as constants: no gradient is
    propagated to them (``backward`` returns ``None`` for both).
    """

    @staticmethod
    def forward(ctx, input, matrix_Low, matrix_High):
        """Compute ``input @ matrix_Low.T`` and ``input @ matrix_High.T``.

        Args:
            ctx: autograd context; the two matrices are stored on it.
            input: tensor whose last axis is contracted with the matrices.
            matrix_Low: 2-D matrix (``.t()`` is applied to it) whose number
                of columns must match the last axis of ``input``.
            matrix_High: 2-D matrix with the same constraint as ``matrix_Low``.

        Returns:
            Tuple ``(L, H)`` of the low- and high-frequency results.
        """
        ctx.save_for_backward(matrix_Low, matrix_High)
        L = torch.matmul(input, matrix_Low.t())
        H = torch.matmul(input, matrix_High.t())
        return L, H

    @staticmethod
    def backward(ctx, grad_L, grad_H):
        """Return the gradient w.r.t. ``input``; matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        matrix_L, matrix_H = ctx.saved_tensors
        # Adjoint of x -> x @ M.T is g -> g @ M; the two branches add up.
        grad_input = torch.add(torch.matmul(grad_L, matrix_L), torch.matmul(grad_H, matrix_H))
        return grad_input, None, None
class IDWTFunction_1D(Function):
    """Autograd function combining a low and a high band along the last axis.

    The transform matrices are treated as constants: no gradient is
    propagated to them (``backward`` returns ``None`` for both).
    """

    @staticmethod
    def forward(ctx, input_L, input_H, matrix_L, matrix_H):
        """Compute ``input_L @ matrix_L + input_H @ matrix_H``.

        Args:
            ctx: autograd context; the two matrices are stored on it.
            input_L: low-frequency tensor; its last axis is contracted with
                the rows of ``matrix_L``.
            input_H: high-frequency tensor; its last axis is contracted with
                the rows of ``matrix_H``.
            matrix_L: 2-D matrix (``.t()`` is applied to it in ``backward``).
            matrix_H: 2-D matrix (``.t()`` is applied to it in ``backward``).

        Returns:
            The sum of the two matrix products.
        """
        ctx.save_for_backward(matrix_L, matrix_H)
        output = torch.add(torch.matmul(input_L, matrix_L), torch.matmul(input_H, matrix_H))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``input_L`` and ``input_H``; matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        matrix_L, matrix_H = ctx.saved_tensors
        # Adjoint of x -> x @ M is g -> g @ M.T.
        grad_L = torch.matmul(grad_output, matrix_L.t())
        grad_H = torch.matmul(grad_output, matrix_H.t())
        return grad_L, grad_H, None, None
class DWTFunction_2D(Function):
    """Autograd function producing four sub-bands from the last two axes.

    ``*_0`` matrices are applied from the left (second-to-last axis) and
    ``*_1`` matrices from the right (last axis). The matrices are treated
    as constants: ``backward`` returns ``None`` for all four of them.
    """

    @staticmethod
    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Compute ``A @ input @ B`` for every low/high pairing of ``A`` and ``B``.

        Args:
            ctx: autograd context; the four matrices are stored on it.
            input: tensor whose last two axes are transformed.
            matrix_Low_0: 2-D matrix left-multiplied onto ``input``.
            matrix_Low_1: 2-D matrix right-multiplied onto the result.
            matrix_High_0: 2-D matrix left-multiplied onto ``input``.
            matrix_High_1: 2-D matrix right-multiplied onto the result.

        Returns:
            Tuple ``(LL, LH, HL, HH)``; the first letter names the ``*_0``
            matrix used and the second letter the ``*_1`` matrix used.
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)
        L = torch.matmul(matrix_Low_0, input)
        H = torch.matmul(matrix_High_0, input)
        LL = torch.matmul(L, matrix_Low_1)
        LH = torch.matmul(L, matrix_High_1)
        HL = torch.matmul(H, matrix_Low_1)
        HH = torch.matmul(H, matrix_High_1)
        return LL, LH, HL, HH

    @staticmethod
    def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH):
        """Return the gradient w.r.t. ``input``; the four matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors
        # Undo the right-multiplications first, merging sub-bands that share a left factor.
        grad_L = torch.add(
            torch.matmul(grad_LL, matrix_Low_1.t()),
            torch.matmul(grad_LH, matrix_High_1.t()),
        )
        grad_H = torch.add(
            torch.matmul(grad_HL, matrix_Low_1.t()),
            torch.matmul(grad_HH, matrix_High_1.t()),
        )
        # Then undo the left-multiplications.
        grad_input = torch.add(
            torch.matmul(matrix_Low_0.t(), grad_L),
            torch.matmul(matrix_High_0.t(), grad_H),
        )
        return grad_input, None, None, None, None
class DWTFunction_2D_tiny(Function):
    """Autograd function computing only the low-low sub-band of the last two axes.

    The high-pass matrices are accepted so the call signature matches the
    full 2-D transform, but they are not used. All matrices are treated as
    constants: ``backward`` returns ``None`` for them.
    """

    @staticmethod
    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Compute ``matrix_Low_0 @ input @ matrix_Low_1``.

        Args:
            ctx: autograd context; the two low-pass matrices are stored on it.
            input: tensor whose last two axes are transformed.
            matrix_Low_0: 2-D matrix left-multiplied onto ``input``.
            matrix_Low_1: 2-D matrix right-multiplied onto the result.
            matrix_High_0: unused.
            matrix_High_1: unused.

        Returns:
            The low-low sub-band ``LL``.
        """
        # Only the low-pass matrices are needed by backward.
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1)
        L = torch.matmul(matrix_Low_0, input)
        LL = torch.matmul(L, matrix_Low_1)
        return LL

    @staticmethod
    def backward(ctx, grad_LL):
        """Return the gradient w.r.t. ``input``; the four matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        matrix_Low_0, matrix_Low_1 = ctx.saved_tensors
        grad_L = torch.matmul(grad_LL, matrix_Low_1.t())
        grad_input = torch.matmul(matrix_Low_0.t(), grad_L)
        return grad_input, None, None, None, None
class IDWTFunction_2D(Function):
    """Autograd function merging four sub-bands back along the last two axes.

    Uses the transposes of the matrices: ``*_1.T`` is applied from the right
    and ``*_0.T`` from the left. The matrices are treated as constants:
    ``backward`` returns ``None`` for all four of them.
    """

    @staticmethod
    def forward(ctx, input_LL, input_LH, input_HL, input_HH,
                matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Sum ``A.T @ band @ B.T`` over the four sub-bands.

        Args:
            ctx: autograd context; the four matrices are stored on it.
            input_LL, input_LH, input_HL, input_HH: sub-band tensors; the
                first letter selects the ``*_0`` matrix and the second the
                ``*_1`` matrix paired with that band.
            matrix_Low_0, matrix_High_0: 2-D matrices whose transposes are
                left-multiplied.
            matrix_Low_1, matrix_High_1: 2-D matrices whose transposes are
                right-multiplied.

        Returns:
            The combined tensor.
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)
        # Merge along the last axis first, keeping the two left-factor groups apart.
        L = torch.add(
            torch.matmul(input_LL, matrix_Low_1.t()),
            torch.matmul(input_LH, matrix_High_1.t()),
        )
        H = torch.add(
            torch.matmul(input_HL, matrix_Low_1.t()),
            torch.matmul(input_HH, matrix_High_1.t()),
        )
        output = torch.add(
            torch.matmul(matrix_Low_0.t(), L),
            torch.matmul(matrix_High_0.t(), H),
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. the four sub-bands; the matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors
        grad_L = torch.matmul(matrix_Low_0, grad_output)
        grad_H = torch.matmul(matrix_High_0, grad_output)
        grad_LL = torch.matmul(grad_L, matrix_Low_1)
        grad_LH = torch.matmul(grad_L, matrix_High_1)
        grad_HL = torch.matmul(grad_H, matrix_Low_1)
        grad_HH = torch.matmul(grad_H, matrix_High_1)
        return grad_LL, grad_LH, grad_HL, grad_HH, None, None, None, None
class DWTFunction_3D(Function):
    """Autograd function producing eight sub-bands from three axes of the input.

    ``*_0`` matrices are left-multiplied onto the last two axes, ``*_1``
    matrices are right-multiplied, and ``*_2`` matrices are left-multiplied
    after swapping axes 2 and 3. Assumes a 5-D input (so that the ``*_2``
    matrices act on axis 2) -- TODO confirm against callers. The matrices
    are treated as constants: ``backward`` returns ``None`` for them.
    """

    @staticmethod
    def forward(ctx, input,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """Apply every low/high combination of the three matrix pairs.

        Args:
            ctx: autograd context; the six matrices are stored on it.
            input: tensor to transform.
            matrix_Low_0, matrix_High_0: 2-D matrices left-multiplied first.
            matrix_Low_1, matrix_High_1: 2-D matrices right-multiplied second.
            matrix_Low_2, matrix_High_2: 2-D matrices left-multiplied last,
                between two swaps of axes 2 and 3.

        Returns:
            Tuple ``(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)``. The first
            letter names the ``*_2`` matrix, the second the ``*_0`` matrix
            and the third the ``*_1`` matrix used for that sub-band.
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)
        L = torch.matmul(matrix_Low_0, input)
        H = torch.matmul(matrix_High_0, input)
        # Swap axes 2 and 3 so the next left-multiplication hits the former axis 2.
        LL = torch.matmul(L, matrix_Low_1).transpose(2, 3)
        LH = torch.matmul(L, matrix_High_1).transpose(2, 3)
        HL = torch.matmul(H, matrix_Low_1).transpose(2, 3)
        HH = torch.matmul(H, matrix_High_1).transpose(2, 3)
        # Swap back after applying the third pair of matrices.
        LLL = torch.matmul(matrix_Low_2, LL).transpose(2, 3)
        LLH = torch.matmul(matrix_Low_2, LH).transpose(2, 3)
        LHL = torch.matmul(matrix_Low_2, HL).transpose(2, 3)
        LHH = torch.matmul(matrix_Low_2, HH).transpose(2, 3)
        HLL = torch.matmul(matrix_High_2, LL).transpose(2, 3)
        HLH = torch.matmul(matrix_High_2, LH).transpose(2, 3)
        HHL = torch.matmul(matrix_High_2, HL).transpose(2, 3)
        HHH = torch.matmul(matrix_High_2, HH).transpose(2, 3)
        return LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH

    @staticmethod
    def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,
                 grad_HLL, grad_HLH, grad_HHL, grad_HHH):
        """Return the gradient w.r.t. ``input``; the six matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        (matrix_Low_0, matrix_Low_1, matrix_Low_2,
         matrix_High_0, matrix_High_1, matrix_High_2) = ctx.saved_tensors
        low_2_t = matrix_Low_2.t()
        high_2_t = matrix_High_2.t()
        # Undo the third stage: merge each low/high pair that shares the other two letters.
        grad_LL = torch.add(
            torch.matmul(low_2_t, grad_LLL.transpose(2, 3)),
            torch.matmul(high_2_t, grad_HLL.transpose(2, 3)),
        ).transpose(2, 3)
        grad_LH = torch.add(
            torch.matmul(low_2_t, grad_LLH.transpose(2, 3)),
            torch.matmul(high_2_t, grad_HLH.transpose(2, 3)),
        ).transpose(2, 3)
        grad_HL = torch.add(
            torch.matmul(low_2_t, grad_LHL.transpose(2, 3)),
            torch.matmul(high_2_t, grad_HHL.transpose(2, 3)),
        ).transpose(2, 3)
        grad_HH = torch.add(
            torch.matmul(low_2_t, grad_LHH.transpose(2, 3)),
            torch.matmul(high_2_t, grad_HHH.transpose(2, 3)),
        ).transpose(2, 3)
        # Undo the second stage (right-multiplications).
        grad_L = torch.add(
            torch.matmul(grad_LL, matrix_Low_1.t()),
            torch.matmul(grad_LH, matrix_High_1.t()),
        )
        grad_H = torch.add(
            torch.matmul(grad_HL, matrix_Low_1.t()),
            torch.matmul(grad_HH, matrix_High_1.t()),
        )
        # Undo the first stage (left-multiplications).
        grad_input = torch.add(
            torch.matmul(matrix_Low_0.t(), grad_L),
            torch.matmul(matrix_High_0.t(), grad_H),
        )
        # One gradient per forward argument: ``input`` plus six matrices.
        return grad_input, None, None, None, None, None, None
class IDWTFunction_3D(Function):
    """Autograd function merging eight sub-bands back into a single tensor.

    Uses the transposes of the matrices: ``*_2.T`` is left-multiplied
    between two swaps of axes 2 and 3, then ``*_1.T`` is right-multiplied,
    then ``*_0.T`` is left-multiplied. Assumes 5-D sub-bands (so that the
    ``*_2`` matrices act on axis 2) -- TODO confirm against callers. The
    matrices are treated as constants: ``backward`` returns ``None`` for them.
    """

    @staticmethod
    def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,
                input_HLL, input_HLH, input_HHL, input_HHH,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """Combine the eight sub-bands using the transposed matrices.

        Args:
            ctx: autograd context; the six matrices are stored on it.
            input_LLL ... input_HHH: sub-band tensors. The first letter
                selects the ``*_2`` matrix, the second the ``*_0`` matrix
                and the third the ``*_1`` matrix paired with that band.
            matrix_Low_0, matrix_High_0: 2-D matrices; transposes are
                left-multiplied in the last stage.
            matrix_Low_1, matrix_High_1: 2-D matrices; transposes are
                right-multiplied in the middle stage.
            matrix_Low_2, matrix_High_2: 2-D matrices; transposes are
                left-multiplied in the first stage.

        Returns:
            The combined tensor.
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)
        low_2_t = matrix_Low_2.t()
        high_2_t = matrix_High_2.t()
        # First stage: merge bands differing only in their first letter.
        input_LL = torch.add(
            torch.matmul(low_2_t, input_LLL.transpose(2, 3)),
            torch.matmul(high_2_t, input_HLL.transpose(2, 3)),
        ).transpose(2, 3)
        input_LH = torch.add(
            torch.matmul(low_2_t, input_LLH.transpose(2, 3)),
            torch.matmul(high_2_t, input_HLH.transpose(2, 3)),
        ).transpose(2, 3)
        input_HL = torch.add(
            torch.matmul(low_2_t, input_LHL.transpose(2, 3)),
            torch.matmul(high_2_t, input_HHL.transpose(2, 3)),
        ).transpose(2, 3)
        input_HH = torch.add(
            torch.matmul(low_2_t, input_LHH.transpose(2, 3)),
            torch.matmul(high_2_t, input_HHH.transpose(2, 3)),
        ).transpose(2, 3)
        # Middle stage: merge along the last axis.
        input_L = torch.add(
            torch.matmul(input_LL, matrix_Low_1.t()),
            torch.matmul(input_LH, matrix_High_1.t()),
        )
        input_H = torch.add(
            torch.matmul(input_HL, matrix_Low_1.t()),
            torch.matmul(input_HH, matrix_High_1.t()),
        )
        # Last stage: merge along the second-to-last axis.
        output = torch.add(
            torch.matmul(matrix_Low_0.t(), input_L),
            torch.matmul(matrix_High_0.t(), input_H),
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. the eight sub-bands; the matrices receive ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        (matrix_Low_0, matrix_Low_1, matrix_Low_2,
         matrix_High_0, matrix_High_1, matrix_High_2) = ctx.saved_tensors
        grad_L = torch.matmul(matrix_Low_0, grad_output)
        grad_H = torch.matmul(matrix_High_0, grad_output)
        # Swap axes 2 and 3 so the next left-multiplication hits the former axis 2.
        grad_LL = torch.matmul(grad_L, matrix_Low_1).transpose(2, 3)
        grad_LH = torch.matmul(grad_L, matrix_High_1).transpose(2, 3)
        grad_HL = torch.matmul(grad_H, matrix_Low_1).transpose(2, 3)
        grad_HH = torch.matmul(grad_H, matrix_High_1).transpose(2, 3)
        grad_LLL = torch.matmul(matrix_Low_2, grad_LL).transpose(2, 3)
        grad_LLH = torch.matmul(matrix_Low_2, grad_LH).transpose(2, 3)
        grad_LHL = torch.matmul(matrix_Low_2, grad_HL).transpose(2, 3)
        grad_LHH = torch.matmul(matrix_Low_2, grad_HH).transpose(2, 3)
        grad_HLL = torch.matmul(matrix_High_2, grad_LL).transpose(2, 3)
        grad_HLH = torch.matmul(matrix_High_2, grad_LH).transpose(2, 3)
        grad_HHL = torch.matmul(matrix_High_2, grad_HL).transpose(2, 3)
        grad_HHH = torch.matmul(matrix_High_2, grad_HH).transpose(2, 3)
        # Eight sub-band gradients followed by one ``None`` per matrix.
        return (grad_LLL, grad_LLH, grad_LHL, grad_LHH,
                grad_HLL, grad_HLH, grad_HHL, grad_HHH,
                None, None, None, None, None, None)