diff --git a/gconv/gnn/modules/gconv.py b/gconv/gnn/modules/gconv.py index 0628ca7..9a1c30d 100644 --- a/gconv/gnn/modules/gconv.py +++ b/gconv/gnn/modules/gconv.py @@ -102,13 +102,13 @@ def __init__( if conv_mode == "2d": self._conv_forward = self._conv2d_forward bias_shape = (1, 1, 1) - self.padding = ( - _triple(padding) if isinstance(self.padding, int) else padding - ) + self.padding = _pair(padding) if isinstance(self.padding, int) else padding elif conv_mode == "3d": self._conv_forward = self._conv3d_forward bias_shape = (1, 1, 1, 1) - self.padding = _pair(padding) if isinstance(self.padding, int) else padding + self.padding = ( + _triple(padding) if isinstance(self.padding, int) else padding + ) else: raise ValueError(