Skip to content

Commit

Permalink
use static method instead of property to allow user defined function …
Browse files Browse the repository at this point in the history
…to access the instance
  • Loading branch information
ASEM000 committed Apr 1, 2024
1 parent 6d73bc5 commit 749a0ee
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 85 deletions.
19 changes: 10 additions & 9 deletions docs/notebooks/custom_convolutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -76,7 +76,7 @@
"\n",
"class CustomConv2D(sk.nn.Conv2D):\n",
" # override the conv_op\n",
" conv_op = my_conv\n",
" conv_op = staticmethod(my_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
Expand Down Expand Up @@ -124,7 +124,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -165,7 +165,7 @@
"\n",
"class CustomDepthwiseConv2D(sk.nn.DepthwiseConv2D):\n",
" # override the conv_op\n",
" conv_op = my_depthwise_conv\n",
" conv_op = staticmethod(my_depthwise_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
Expand Down Expand Up @@ -214,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -223,7 +223,7 @@
"(2, 10, 10)"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -238,6 +238,7 @@
"\n",
"\n",
"def my_custom_conv(\n",
" self,\n",
" input: jax.Array,\n",
" weight: jax.Array,\n",
" bias: jax.Array | None,\n",
Expand Down Expand Up @@ -273,7 +274,7 @@
"\n",
"class CustomConv2D(sk.nn.Conv2D):\n",
" # override the conv_op\n",
" conv_op = my_custom_conv\n",
" conv_op = staticmethod(my_custom_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
Expand Down Expand Up @@ -310,7 +311,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
"version": "3.12.2"
}
},
"nbformat": 4,
Expand Down
68 changes: 34 additions & 34 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,10 +770,10 @@ def __init__(self, kernel_size: int | tuple[int, int]):
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None)
args = (image, self.kernel_size)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class AvgBlur2D(BaseAvgBlur2D):
Expand All @@ -797,7 +797,7 @@ class AvgBlur2D(BaseAvgBlur2D):
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
"""

filter_op = avg_blur_2d
filter_op = staticmethod(avg_blur_2d)


class FFTAvgBlur2D(BaseAvgBlur2D):
Expand All @@ -821,7 +821,7 @@ class FFTAvgBlur2D(BaseAvgBlur2D):
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
"""

filter_op = fft_avg_blur_2d
filter_op = staticmethod(fft_avg_blur_2d)


class BaseGaussianBlur2D(sk.TreeClass):
Expand All @@ -839,10 +839,10 @@ def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None, None)
sigma = jax.lax.stop_gradient(self.sigma)
args = (image, self.kernel_size, sigma)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class GaussianBlur2D(BaseGaussianBlur2D):
Expand All @@ -867,7 +867,7 @@ class GaussianBlur2D(BaseGaussianBlur2D):
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

filter_op = gaussian_blur_2d
filter_op = staticmethod(gaussian_blur_2d)


class FFTGaussianBlur2D(BaseGaussianBlur2D):
Expand All @@ -892,7 +892,7 @@ class FFTGaussianBlur2D(BaseGaussianBlur2D):
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

filter_op = fft_gaussian_blur_2d
filter_op = staticmethod(fft_gaussian_blur_2d)


class UnsharpMask2D(BaseGaussianBlur2D):
Expand All @@ -917,7 +917,7 @@ class UnsharpMask2D(BaseGaussianBlur2D):
[1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]]
"""

filter_op = unsharp_mask_2d
filter_op = staticmethod(unsharp_mask_2d)


class FFTUnsharpMask2D(BaseGaussianBlur2D):
Expand All @@ -942,7 +942,7 @@ class FFTUnsharpMask2D(BaseGaussianBlur2D):
[1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]]
"""

filter_op = fft_unsharp_mask_2d
filter_op = staticmethod(fft_unsharp_mask_2d)


class BoxBlur2DBase(sk.TreeClass):
Expand All @@ -953,10 +953,10 @@ def __init__(self, kernel_size: int | tuple[int, int]):
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None)
args = (image, self.kernel_size)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class BoxBlur2D(BoxBlur2DBase):
Expand All @@ -980,7 +980,7 @@ class BoxBlur2D(BoxBlur2DBase):
[0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]]
"""

filter_op = box_blur_2d
filter_op = staticmethod(box_blur_2d)


class FFTBoxBlur2D(BoxBlur2DBase):
Expand All @@ -1004,7 +1004,7 @@ class FFTBoxBlur2D(BoxBlur2DBase):
[0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]]
"""

filter_op = fft_box_blur_2d
filter_op = staticmethod(fft_box_blur_2d)


class Laplacian2DBase(sk.TreeClass):
Expand All @@ -1015,10 +1015,10 @@ def __init__(self, kernel_size: int | tuple[int, int]):
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None)
args = (image, self.kernel_size)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class Laplacian2D(Laplacian2DBase):
Expand All @@ -1045,7 +1045,7 @@ class Laplacian2D(Laplacian2DBase):
The laplacian considers all the neighbors of a pixel.
"""

filter_op = laplacian_2d
filter_op = staticmethod(laplacian_2d)


class FFTLaplacian2D(Laplacian2DBase):
Expand All @@ -1072,7 +1072,7 @@ class FFTLaplacian2D(Laplacian2DBase):
The laplacian considers all the neighbors of a pixel.
"""

filter_op = fft_laplacian_2d
filter_op = staticmethod(fft_laplacian_2d)


class MotionBlur2DBase(sk.TreeClass):
Expand All @@ -1092,10 +1092,10 @@ def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None, None, None)
angle, direction = jax.lax.stop_gradient((self.angle, self.direction))
args = (image, self.kernel_size, angle, direction)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class MotionBlur2D(MotionBlur2DBase):
Expand All @@ -1120,7 +1120,7 @@ class MotionBlur2D(MotionBlur2DBase):
[ 6.472714 10.020969 10.770187 9.100007 ]]]
"""

filter_op = motion_blur_2d
filter_op = staticmethod(motion_blur_2d)


class FFTMotionBlur2D(MotionBlur2DBase):
Expand All @@ -1145,7 +1145,7 @@ class FFTMotionBlur2D(MotionBlur2DBase):
[ 6.472714 10.020969 10.770187 9.100007 ]]]
"""

filter_op = fft_motion_blur_2d
filter_op = staticmethod(fft_motion_blur_2d)


class MedianBlur2D(sk.TreeClass):
Expand Down Expand Up @@ -1189,10 +1189,10 @@ def __call__(self, image: CHWArray) -> CHWArray:
class Sobel2DBase(sk.TreeClass):
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(type(self).filter_op)(image)
return jax.vmap(self.filter_op)(image)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class Sobel2D(Sobel2DBase):
Expand All @@ -1216,7 +1216,7 @@ class Sobel2D(Sobel2DBase):
[78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]]
"""

filter_op = sobel_2d
filter_op = staticmethod(sobel_2d)


class FFTSobel2D(Sobel2DBase):
Expand All @@ -1237,7 +1237,7 @@ class FFTSobel2D(Sobel2DBase):
[78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]]
"""

filter_op = fft_sobel_2d
filter_op = staticmethod(fft_sobel_2d)


class ElasticTransform2DBase(sk.TreeClass):
Expand All @@ -1255,9 +1255,9 @@ def __init__(
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
in_axes = (None, 0, None, None, None)
args = (image, self.kernel_size, self.sigma, self.alpha)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(key, *args)
return jax.vmap(self.filter_op, in_axes=in_axes)(key, *args)

filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
spatial_ndim: int = 2


Expand Down Expand Up @@ -1288,7 +1288,7 @@ class ElasticTransform2D(ElasticTransform2DBase):
[21. 21.659977 21.43855 21.138866 22.583244 ]]]
"""

filter_op = elastic_transform_2d
filter_op = staticmethod(elastic_transform_2d)


class FFTElasticTransform2D(ElasticTransform2DBase):
Expand Down Expand Up @@ -1318,7 +1318,7 @@ class FFTElasticTransform2D(ElasticTransform2DBase):
[21. 21.659977 21.43855 21.138866 22.583244 ]]]
"""

filter_op = fft_elastic_transform_2d
filter_op = staticmethod(fft_elastic_transform_2d)


class BilateralBlur2D(sk.TreeClass):
Expand Down Expand Up @@ -1427,10 +1427,10 @@ def __init__(
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None, None)
args = (image, self.kernel_size, self.strides)
return jax.vmap(type(self).filter_op, in_axes=in_axes)(*args)
return jax.vmap(self.filter_op, in_axes=in_axes)(*args)

spatial_ndim: int = 2
filter_op = property(abc.abstractmethod(lambda _: ...))
filter_op = staticmethod(abc.abstractmethod(lambda _: ...))


class BlurPool2D(BlurPool2DBase):
Expand All @@ -1453,7 +1453,7 @@ class BlurPool2D(BlurPool2DBase):
[11.0625 16. 12.9375]]]
"""

filter_op = blur_pool_2d
filter_op = staticmethod(blur_pool_2d)


class FFTBlurPool2D(BlurPool2DBase):
Expand All @@ -1476,4 +1476,4 @@ class FFTBlurPool2D(BlurPool2DBase):
[11.0625 16. 12.9375]]]
"""

filter_op = fft_blur_pool_2d
filter_op = staticmethod(fft_blur_pool_2d)
Loading

0 comments on commit 749a0ee

Please sign in to comment.