diff --git a/docs/notebooks/custom_convolutions.ipynb b/docs/notebooks/custom_convolutions.ipynb index ee7eba2..0c23f7b 100644 --- a/docs/notebooks/custom_convolutions.ipynb +++ b/docs/notebooks/custom_convolutions.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ "\n", "class CustomConv2D(sk.nn.Conv2D):\n", " # override the conv_op\n", - " conv_op = my_conv\n", + " conv_op = staticmethod(my_conv)\n", "\n", "\n", "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", @@ -124,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ "\n", "class CustomDepthwiseConv2D(sk.nn.DepthwiseConv2D):\n", " # override the conv_op\n", - " conv_op = my_depthwise_conv\n", + " conv_op = staticmethod(my_depthwise_conv)\n", "\n", "\n", "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", @@ -214,7 +214,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -223,7 +223,7 @@ "(2, 10, 10)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -238,6 +238,7 @@ "\n", "\n", "def my_custom_conv(\n", + " self,\n", " input: jax.Array,\n", " weight: jax.Array,\n", " bias: jax.Array | None,\n", @@ -273,7 +274,7 @@ "\n", "class CustomConv2D(sk.nn.Conv2D):\n", " # override the conv_op\n", - " conv_op = my_custom_conv\n", + " conv_op = staticmethod(my_custom_conv)\n", "\n", "\n", "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", @@ -310,7 +311,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index d49d4eb..a7e2974 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -770,10 +770,10 @@ def __init__(self, kernel_size: int | tuple[int, int]): def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class AvgBlur2D(BaseAvgBlur2D): @@ -797,7 +797,7 @@ class AvgBlur2D(BaseAvgBlur2D): [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] """ - filter_op = avg_blur_2d + filter_op = staticmethod(avg_blur_2d) class FFTAvgBlur2D(BaseAvgBlur2D): @@ -821,7 +821,7 @@ class FFTAvgBlur2D(BaseAvgBlur2D): [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] """ - filter_op = fft_avg_blur_2d + filter_op = staticmethod(fft_avg_blur_2d) class BaseGaussianBlur2D(sk.TreeClass): @@ -839,10 +839,10 @@ def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) sigma = jax.lax.stop_gradient(self.sigma) args = (image, self.kernel_size, sigma) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class GaussianBlur2D(BaseGaussianBlur2D): @@ -867,7 +867,7 @@ class GaussianBlur2D(BaseGaussianBlur2D): [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] """ - filter_op = gaussian_blur_2d + filter_op = staticmethod(gaussian_blur_2d) class FFTGaussianBlur2D(BaseGaussianBlur2D): @@ -892,7 +892,7 @@ class FFTGaussianBlur2D(BaseGaussianBlur2D): [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] """ - filter_op = fft_gaussian_blur_2d + filter_op = staticmethod(fft_gaussian_blur_2d) class UnsharpMask2D(BaseGaussianBlur2D): @@ -917,7 +917,7 @@ class UnsharpMask2D(BaseGaussianBlur2D): [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]] """ - filter_op = unsharp_mask_2d + filter_op = staticmethod(unsharp_mask_2d) class FFTUnsharpMask2D(BaseGaussianBlur2D): @@ -942,7 +942,7 @@ class FFTUnsharpMask2D(BaseGaussianBlur2D): [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]] """ - filter_op = fft_unsharp_mask_2d + filter_op = staticmethod(fft_unsharp_mask_2d) class BoxBlur2DBase(sk.TreeClass): @@ -953,10 +953,10 @@ def __init__(self, kernel_size: int | tuple[int, int]): def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class BoxBlur2D(BoxBlur2DBase): @@ -980,7 +980,7 @@ class BoxBlur2D(BoxBlur2DBase): [0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]] """ - filter_op = box_blur_2d + filter_op = staticmethod(box_blur_2d) class FFTBoxBlur2D(BoxBlur2DBase): @@ -1004,7 +1004,7 @@ class FFTBoxBlur2D(BoxBlur2DBase): [0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]] """ - filter_op = fft_box_blur_2d + filter_op = staticmethod(fft_box_blur_2d) class Laplacian2DBase(sk.TreeClass): @@ -1015,10 +1015,10 @@ def __init__(self, kernel_size: int | tuple[int, int]): def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class Laplacian2D(Laplacian2DBase): @@ -1045,7 +1045,7 @@ class Laplacian2D(Laplacian2DBase): The laplacian considers all the neighbors of a pixel. """ - filter_op = laplacian_2d + filter_op = staticmethod(laplacian_2d) class FFTLaplacian2D(Laplacian2DBase): @@ -1072,7 +1072,7 @@ class FFTLaplacian2D(Laplacian2DBase): The laplacian considers all the neighbors of a pixel. """ - filter_op = fft_laplacian_2d + filter_op = staticmethod(fft_laplacian_2d) class MotionBlur2DBase(sk.TreeClass): @@ -1092,10 +1092,10 @@ def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None, None) angle, direction = jax.lax.stop_gradient((self.angle, self.direction)) args = (image, self.kernel_size, angle, direction) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class MotionBlur2D(MotionBlur2DBase): @@ -1120,7 +1120,7 @@ class MotionBlur2D(MotionBlur2DBase): [ 6.472714 10.020969 10.770187 9.100007 ]]] """ - filter_op = motion_blur_2d + filter_op = staticmethod(motion_blur_2d) class FFTMotionBlur2D(MotionBlur2DBase): @@ -1145,7 +1145,7 @@ class FFTMotionBlur2D(MotionBlur2DBase): [ 6.472714 10.020969 10.770187 9.100007 ]]] """ - filter_op = fft_motion_blur_2d + filter_op = staticmethod(fft_motion_blur_2d) class MedianBlur2D(sk.TreeClass): @@ -1189,10 +1189,10 @@ def __call__(self, image: CHWArray) -> CHWArray: class Sobel2DBase(sk.TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: - return jax.vmap(type(self).filter_op)(image) + return jax.vmap(self.filter_op)(image) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class Sobel2D(Sobel2DBase): @@ -1216,7 +1216,7 @@ class Sobel2D(Sobel2DBase): [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]] """ - filter_op = sobel_2d + filter_op = staticmethod(sobel_2d) class FFTSobel2D(Sobel2DBase): @@ -1237,7 +1237,7 @@ class FFTSobel2D(Sobel2DBase): [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]] """ - filter_op = fft_sobel_2d + filter_op = staticmethod(fft_sobel_2d) class ElasticTransform2DBase(sk.TreeClass): @@ -1255,9 +1255,9 @@ def __init__( def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: in_axes = (None, 0, None, None, None) args = (image, self.kernel_size, self.sigma, self.alpha) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(key, *args) + return jax.vmap(self.filter_op, in_axes=in_axes)(key, *args) - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) spatial_ndim: int = 2 @@ -1288,7 +1288,7 @@ class ElasticTransform2D(ElasticTransform2DBase): [21. 21.659977 21.43855 21.138866 22.583244 ]]] """ - filter_op = elastic_transform_2d + filter_op = staticmethod(elastic_transform_2d) class FFTElasticTransform2D(ElasticTransform2DBase): @@ -1318,7 +1318,7 @@ class FFTElasticTransform2D(ElasticTransform2DBase): [21. 21.659977 21.43855 21.138866 22.583244 ]]] """ - filter_op = fft_elastic_transform_2d + filter_op = staticmethod(fft_elastic_transform_2d) class BilateralBlur2D(sk.TreeClass): @@ -1427,10 +1427,10 @@ def __init__( def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) args = (image, self.kernel_size, self.strides) - return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args) + return jax.vmap(self.filter_op, in_axes=in_axes)(*args) spatial_ndim: int = 2 - filter_op = property(abc.abstractmethod(lambda _: ...)) + filter_op = staticmethod(abc.abstractmethod(lambda _: ...)) class BlurPool2D(BlurPool2DBase): @@ -1453,7 +1453,7 @@ class BlurPool2D(BlurPool2DBase): [11.0625 16. 12.9375]]] """ - filter_op = blur_pool_2d + filter_op = staticmethod(blur_pool_2d) class FFTBlurPool2D(BlurPool2DBase): @@ -1476,4 +1476,4 @@ class FFTBlurPool2D(BlurPool2DBase): [11.0625 16. 12.9375]]] """ - filter_op = fft_blur_pool_2d + filter_op = staticmethod(fft_blur_pool_2d) diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index 8ea274b..1f99e86 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -608,7 +608,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: strides=self.strides, ) - return type(self).conv_op( + return self.conv_op( input=input, weight=self.weight, bias=self.bias, @@ -620,7 +620,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class Conv1D(ConvND): @@ -708,7 +708,7 @@ class Conv1D(ConvND): """ spatial_ndim: int = 1 - conv_op = conv_nd + conv_op = staticmethod(conv_nd) class Conv2D(ConvND): @@ -796,7 +796,7 @@ class Conv2D(ConvND): """ spatial_ndim: int = 2 - conv_op = conv_nd + conv_op = staticmethod(conv_nd) class Conv3D(ConvND): @@ -884,7 +884,7 @@ class Conv3D(ConvND): """ spatial_ndim: int = 3 - conv_op = conv_nd + conv_op = staticmethod(conv_nd) class FFTConv1D(ConvND): @@ -972,7 +972,7 @@ class FFTConv1D(ConvND): """ spatial_ndim: int = 1 - conv_op = fft_conv_nd + conv_op = staticmethod(fft_conv_nd) class FFTConv2D(ConvND): @@ -1060,7 +1060,7 @@ class FFTConv2D(ConvND): """ spatial_ndim: int = 2 - conv_op = fft_conv_nd + conv_op = staticmethod(fft_conv_nd) class FFTConv3D(ConvND): @@ -1148,7 +1148,7 @@ class FFTConv3D(ConvND): """ spatial_ndim: int = 3 - conv_op = fft_conv_nd + conv_op = staticmethod(fft_conv_nd) class ConvNDTranspose(sk.TreeClass): @@ -1212,7 +1212,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: strides=self.strides, ) - return type(self).conv_op( + return self.conv_op( input=input, weight=self.weight, bias=self.bias, @@ -1224,7 +1224,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class Conv1DTranspose(ConvNDTranspose): @@ -1316,7 +1316,7 @@ class Conv1DTranspose(ConvNDTranspose): """ spatial_ndim: int = 1 - conv_op = conv_nd_transpose + conv_op = staticmethod(conv_nd_transpose) class Conv2DTranspose(ConvNDTranspose): @@ -1407,7 +1407,7 @@ class Conv2DTranspose(ConvNDTranspose): """ spatial_ndim: int = 2 - conv_op = conv_nd_transpose + conv_op = staticmethod(conv_nd_transpose) class Conv3DTranspose(ConvNDTranspose): @@ -1499,7 +1499,7 @@ class Conv3DTranspose(ConvNDTranspose): """ spatial_ndim: int = 3 - conv_op = conv_nd_transpose + conv_op = staticmethod(conv_nd_transpose) class FFTConv1DTranspose(ConvNDTranspose): @@ -1591,7 +1591,7 @@ class FFTConv1DTranspose(ConvNDTranspose): """ spatial_ndim: int = 1 - conv_op = fft_conv_nd_transpose + conv_op = staticmethod(fft_conv_nd_transpose) class FFTConv2DTranspose(ConvNDTranspose): @@ -1683,7 +1683,7 @@ class FFTConv2DTranspose(ConvNDTranspose): """ spatial_ndim: int = 2 - conv_op = fft_conv_nd_transpose + conv_op = staticmethod(fft_conv_nd_transpose) class FFTConv3DTranspose(ConvNDTranspose): @@ -1775,7 +1775,7 @@ class FFTConv3DTranspose(ConvNDTranspose): """ spatial_ndim: int = 3 - conv_op = fft_conv_nd_transpose + conv_op = staticmethod(fft_conv_nd_transpose) class DepthwiseConvND(sk.TreeClass): @@ -1828,7 +1828,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: strides=self.strides, ) - return type(self).conv_op( + return self.conv_op( input=input, weight=self.weight, bias=self.bias, @@ -1838,7 +1838,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class DepthwiseConv1D(DepthwiseConvND): @@ -1913,7 +1913,7 @@ class DepthwiseConv1D(DepthwiseConvND): """ spatial_ndim: int = 1 - conv_op = depthwise_conv_nd + conv_op = staticmethod(depthwise_conv_nd) class DepthwiseConv2D(DepthwiseConvND): @@ -1988,7 +1988,7 @@ class DepthwiseConv2D(DepthwiseConvND): """ spatial_ndim: int = 2 - conv_op = depthwise_conv_nd + conv_op = staticmethod(depthwise_conv_nd) class DepthwiseConv3D(DepthwiseConvND): @@ -2063,7 +2063,7 @@ class DepthwiseConv3D(DepthwiseConvND): """ spatial_ndim: int = 3 - conv_op = depthwise_conv_nd + conv_op = staticmethod(depthwise_conv_nd) class DepthwiseFFTConv1D(DepthwiseConvND): @@ -2138,7 +2138,7 @@ class DepthwiseFFTConv1D(DepthwiseConvND): """ spatial_ndim: int = 1 - conv_op = depthwise_fft_conv_nd + conv_op = staticmethod(depthwise_fft_conv_nd) class DepthwiseFFTConv2D(DepthwiseConvND): @@ -2213,7 +2213,7 @@ class DepthwiseFFTConv2D(DepthwiseConvND): """ spatial_ndim: int = 2 - conv_op = depthwise_fft_conv_nd + conv_op = staticmethod(depthwise_fft_conv_nd) class DepthwiseFFTConv3D(DepthwiseConvND): @@ -2288,7 +2288,7 @@ class DepthwiseFFTConv3D(DepthwiseConvND): """ spatial_ndim: int = 3 - conv_op = depthwise_fft_conv_nd + conv_op = staticmethod(depthwise_fft_conv_nd) class SeparableConvND(sk.TreeClass): @@ -2365,7 +2365,7 @@ def __call__( strides=self.strides, ) - return type(self).conv_op( + return self.conv_op( input=input, depthwise_weight=self.depthwise_weight, pointwise_weight=self.pointwise_weight, @@ -2378,7 +2378,7 @@ def __call__( ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class SeparableConv1D(SeparableConvND): @@ -2464,7 +2464,7 @@ class SeparableConv1D(SeparableConvND): """ spatial_ndim: int = 1 - conv_op = separable_conv_nd + conv_op = staticmethod(separable_conv_nd) class SeparableConv2D(SeparableConvND): @@ -2550,7 +2550,7 @@ class SeparableConv2D(SeparableConvND): """ spatial_ndim: int = 2 - conv_op = separable_conv_nd + conv_op = staticmethod(separable_conv_nd) class SeparableConv3D(SeparableConvND): @@ -2636,7 +2636,7 @@ class SeparableConv3D(SeparableConvND): """ spatial_ndim: int = 3 - conv_op = separable_conv_nd + conv_op = staticmethod(separable_conv_nd) class SeparableFFTConv1D(SeparableConvND): @@ -2722,7 +2722,7 @@ class SeparableFFTConv1D(SeparableConvND): """ spatial_ndim: int = 1 - conv_op = separable_fft_conv_nd + conv_op = staticmethod(separable_fft_conv_nd) class SeparableFFTConv2D(SeparableConvND): @@ -2808,7 +2808,7 @@ class SeparableFFTConv2D(SeparableConvND): """ spatial_ndim: int = 2 - conv_op = separable_fft_conv_nd + conv_op = staticmethod(separable_fft_conv_nd) class SeparableFFTConv3D(SeparableConvND): @@ -2894,7 +2894,7 @@ class SeparableFFTConv3D(SeparableConvND): """ spatial_ndim: int = 3 - conv_op = separable_fft_conv_nd + conv_op = staticmethod(separable_fft_conv_nd) class SpectralConvND(sk.TreeClass): @@ -2921,7 +2921,7 @@ def __init__( @ft.partial(validate_spatial_ndim, argnum=0) @ft.partial(validate_in_features_shape, axis=0) def __call__(self, input: jax.Array) -> jax.Array: - return type(self).conv_op( + return self.conv_op( input=input, weight_r=self.weight_r, weight_i=self.weight_i, @@ -2929,7 +2929,7 @@ def __call__(self, input: jax.Array) -> jax.Array: ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class SpectralConv1D(SpectralConvND): @@ -2986,7 +2986,7 @@ class SpectralConv1D(SpectralConvND): """ spatial_ndim: int = 1 - conv_op = spectral_conv_nd + conv_op = staticmethod(spectral_conv_nd) class SpectralConv2D(SpectralConvND): @@ -3044,7 +3044,7 @@ class SpectralConv2D(SpectralConvND): """ spatial_ndim: int = 2 - conv_op = spectral_conv_nd + conv_op = staticmethod(spectral_conv_nd) class SpectralConv3D(SpectralConvND): @@ -3102,7 +3102,7 @@ class SpectralConv3D(SpectralConvND): """ spatial_ndim: int = 3 - conv_op = spectral_conv_nd + conv_op = staticmethod(spectral_conv_nd) def is_lazy_call(instance, *_, **__) -> bool: @@ -3205,7 +3205,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: ``(out_features, in_features * prod(kernel_size), *out_size)`` use ``None`` for no mask. """ - return type(self).conv_op( + return self.conv_op( input=input, weight=self.weight, bias=self.bias, @@ -3217,7 +3217,7 @@ def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: ) spatial_ndim = property(abc.abstractmethod(lambda _: ...)) - conv_op = property(abc.abstractmethod(lambda _: ...)) + conv_op = staticmethod(abc.abstractmethod(lambda _: ...)) class Conv1DLocal(ConvNDLocal): @@ -3299,7 +3299,7 @@ class Conv1DLocal(ConvNDLocal): """ spatial_ndim: int = 1 - conv_op = local_conv_nd + conv_op = staticmethod(local_conv_nd) class Conv2DLocal(ConvNDLocal): @@ -3381,7 +3381,7 @@ class Conv2DLocal(ConvNDLocal): """ spatial_ndim: int = 2 - conv_op = local_conv_nd + conv_op = staticmethod(local_conv_nd) class Conv3DLocal(ConvNDLocal): @@ -3463,4 +3463,4 @@ class Conv3DLocal(ConvNDLocal): """ spatial_ndim: int = 3 - conv_op = local_conv_nd + conv_op = staticmethod(local_conv_nd)