From eccb7792733ee59cfbf465820b5a53768d8e1307 Mon Sep 17 00:00:00 2001
From: ge26yim
Date: Fri, 10 May 2024 13:53:04 +0200
Subject: [PATCH 1/3] fix DCT Type-II

---
 dct_transform.py    | 12 +++++++++---
 inn_architecture.py |  2 ++
 train.py            |  2 ++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/dct_transform.py b/dct_transform.py
index a3bf601..b9f6739 100644
--- a/dct_transform.py
+++ b/dct_transform.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import pdb
 
 '''adapted from https://github.com/zh217/torch-dct'''
 
@@ -11,18 +12,25 @@ def dct_1d(x):
     :param x: the input signal
     :return: the DCT-II of the signal over the last dimension
     """
+
     x_shape = x.shape
     N = x_shape[-1]
+
+    # reshape the input signal to a 2D tensor with the last dim flattened
     x = x.contiguous().view(-1, N)
+    x = x.unsqueeze(0)
 
+    # rearrange cols
     v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
 
-    Vc = torch.rfft(v, 1, onesided=False)
+    # Apply 1-D real-to-complex Fast Fourier Transform along last dim
+    Vc = torch.view_as_real( torch.fft.fft(v, dim=1))
 
     k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
     W_r = torch.cos(k)
     W_i = torch.sin(k)
 
+    # Apply DCT formula
     V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
 
     V = 2 * V.view(*x_shape)
@@ -124,5 +132,3 @@ def output_dims(self, input_dims):
 
             err = torch.abs(x - x_inv).max()
             print(N, err.item(), flush=True)
-
-
diff --git a/inn_architecture.py b/inn_architecture.py
index 983acb6..a5c7864 100644
--- a/inn_architecture.py
+++ b/inn_architecture.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 from torch.nn.functional import conv2d, interpolate
 import numpy as np
+import pdb
 
 import FrEIA.framework as Ff
 import FrEIA.modules as Fm
@@ -123,6 +124,7 @@ def fc_constr(c_in, c_out):
     channels = classifier.input_channels
 
     if classifier.dataset == 'MNIST':
+        # import pdb; pdb.set_trace()
         nodes.append(Ff.Node(nodes[-1].out0, Fm.Reshape, {'target_dim':(1, *classifier.dims)}))
         nodes.append(Ff.Node(nodes[-1].out0, Fm.HaarDownsampling, {'rebalance':1.}))
         channels *= 4
diff --git a/train.py b/train.py
index bd024a2..a943c71 100644
--- a/train.py
+++ b/train.py
@@ -97,6 +97,8 @@ def log_write(line, endline='\n'):
 
     for i_batch, (x,l) in enumerate(dataset.train_loader):
 
+        import pdb; pdb.set_trace()
+
         x, y = x.cuda(), dataset.onehot(l.cuda(), label_smoothing)
 
         losses = inn(x, y)

From 4ad12d25253ff3494fa0c9bf6d07100db67ba3b0 Mon Sep 17 00:00:00 2001
From: ge26yim
Date: Fri, 10 May 2024 13:59:11 +0200
Subject: [PATCH 2/3] fix DCT Type-II

---
 dct_transform.py    | 13 +++++++++----
 inn_architecture.py |  2 ++
 train.py            |  2 ++
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/dct_transform.py b/dct_transform.py
index a3bf601..c6a8cde 100644
--- a/dct_transform.py
+++ b/dct_transform.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import pdb
 
 '''adapted from https://github.com/zh217/torch-dct'''
 
@@ -11,18 +12,24 @@ def dct_1d(x):
     :param x: the input signal
     :return: the DCT-II of the signal over the last dimension
     """
+
     x_shape = x.shape
     N = x_shape[-1]
+
+    # reshape the input signal to a 2D tensor with the last dim flattened
     x = x.contiguous().view(-1, N)
 
-    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+    # rearrange cols
+    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=-1)
 
-    Vc = torch.rfft(v, 1, onesided=False)
+    # Apply 1-D real-to-complex Fast Fourier Transform along last dim
+    Vc = torch.view_as_real( torch.fft.fft(v, dim=1))
 
     k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
     W_r = torch.cos(k)
     W_i = torch.sin(k)
 
+    # Apply DCT formula
     V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
 
     V = 2 * V.view(*x_shape)
@@ -124,5 +131,3 @@ def output_dims(self, input_dims):
 
             err = torch.abs(x - x_inv).max()
             print(N, err.item(), flush=True)
-
-
diff --git a/inn_architecture.py b/inn_architecture.py
index 983acb6..a5c7864 100644
--- a/inn_architecture.py
+++ b/inn_architecture.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 from torch.nn.functional import conv2d, interpolate
 import numpy as np
+import pdb
 
 import FrEIA.framework as Ff
 import FrEIA.modules as Fm
@@ -123,6 +124,7 @@ def fc_constr(c_in, c_out):
     channels = classifier.input_channels
 
     if classifier.dataset == 'MNIST':
+        # import pdb; pdb.set_trace()
         nodes.append(Ff.Node(nodes[-1].out0, Fm.Reshape, {'target_dim':(1, *classifier.dims)}))
         nodes.append(Ff.Node(nodes[-1].out0, Fm.HaarDownsampling, {'rebalance':1.}))
         channels *= 4
diff --git a/train.py b/train.py
index bd024a2..a943c71 100644
--- a/train.py
+++ b/train.py
@@ -97,6 +97,8 @@ def log_write(line, endline='\n'):
 
     for i_batch, (x,l) in enumerate(dataset.train_loader):
 
+        import pdb; pdb.set_trace()
+
         x, y = x.cuda(), dataset.onehot(l.cuda(), label_smoothing)
 
         losses = inn(x, y)

From b79b90c4f784f295c5884919d5852494e31a8db3 Mon Sep 17 00:00:00 2001
From: ge26yim
Date: Fri, 10 May 2024 14:04:17 +0200
Subject: [PATCH 3/3] fix inverse DCT-III

---
 dct_transform.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/dct_transform.py b/dct_transform.py
index 4659c68..fab449a 100644
--- a/dct_transform.py
+++ b/dct_transform.py
@@ -18,7 +18,6 @@ def dct_1d(x):
 
     # reshape the input signal to a 2D tensor with the last dim flattened
     x = x.contiguous().view(-1, N)
-    x = x.unsqueeze(0)
 
     # rearrange cols
     v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=-1)
@@ -62,7 +61,7 @@ def idct_1d(X):
 
     V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
 
-    v = torch.irfft(V, 1, onesided=False)
+    v = torch.fft.irfft(torch.view_as_complex(V),n=V.shape[1], dim=1)
     x = v.new_zeros(v.shape)
     x[:, ::2] += v[:, :N - (N // 2)]
     x[:, 1::2] += v.flip([1])[:, :N // 2]
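
A quick way to sanity-check the series as a whole (not part of the patches; a minimal sketch that assumes dct_transform.py is importable from the repo root and mirrors the module's own self-test, which prints the maximum round-trip error):

    # Hypothetical round-trip check for the torch.fft-based DCT-II / inverse DCT-III.
    import torch
    from dct_transform import dct_1d, idct_1d

    x = torch.randn(4, 32, dtype=torch.float64)   # batch of 1-D signals
    X = dct_1d(x)                                 # DCT-II via torch.fft.fft + view_as_real
    x_rec = idct_1d(X)                            # inverse via torch.fft.irfft + view_as_complex
    print(torch.abs(x - x_rec).max().item())      # expected to be near machine precision

With the deprecated torch.rfft/torch.irfft calls replaced as in the hunks above, this sketch should run on any PyTorch release that ships the torch.fft module (1.8+).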