From 3eb07798116caef951df935ecae92dfb698ccc8b Mon Sep 17 00:00:00 2001 From: Fergal Cotter Date: Thu, 3 Oct 2019 22:35:21 +0100 Subject: [PATCH] Fixed expand dims bug --- pytorch_wavelets/dtcwt/lowlevel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_wavelets/dtcwt/lowlevel.py b/pytorch_wavelets/dtcwt/lowlevel.py index 0112719..88fa267 100644 --- a/pytorch_wavelets/dtcwt/lowlevel.py +++ b/pytorch_wavelets/dtcwt/lowlevel.py @@ -59,8 +59,7 @@ def prep_filt(h, c, transpose=False): """ Prepares an array to be of the correct format for pytorch. Can also specify whether to make it a row filter (set tranpose=True)""" h = _as_col_vector(h)[::-1] - #h = np.reshape(h, [1, 1, *h.shape]) - h = np.expand_dims(h, (0,1)) + h = h[None, None, :] h = np.repeat(h, repeats=c, axis=0) if transpose: h = h.transpose((0,1,3,2))