Skip to content

Commit

Permalink
fix: resolving pylint issues in custom_tf_addons
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 17, 2024
1 parent 7867711 commit d6dd2e8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"""

import math
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -48,7 +47,7 @@ def get_ndims(image):
return image.get_shape().ndims or tf.rank(image)


def to_4D_image(image):
def to_4d_image(image):
"""Convert 2/3/4D image to 4D image.
Args:
Expand All @@ -63,7 +62,7 @@ def to_4D_image(image):
]):
ndims = image.get_shape().ndims
if ndims is None:
return _dynamic_to_4D_image(image)
return _dynamic_to_4d_image(image)
elif ndims == 2:
return image[None, :, :, None]
elif ndims == 3:
Expand All @@ -72,7 +71,7 @@ def to_4D_image(image):
return image


def _dynamic_to_4D_image(image):
def _dynamic_to_4d_image(image):
shape = tf.shape(image)
original_rank = tf.rank(image)
# 4D image => [N, H, W, C] or [N, C, H, W]
Expand All @@ -91,7 +90,7 @@ def _dynamic_to_4D_image(image):
return tf.reshape(image, new_shape)


def from_4D_image(image, ndims):
def from_4d_image(image, ndims):
"""Convert back to an image with `ndims` rank.
Args:
Expand All @@ -105,7 +104,7 @@ def from_4D_image(image, ndims):
[tf.debugging.assert_rank(image, 4,
message="`image` must be 4D tensor")]):
if isinstance(ndims, tf.Tensor):
return _dynamic_from_4D_image(image, ndims)
return _dynamic_from_4d_image(image, ndims)
elif ndims == 2:
return tf.squeeze(image, [0, 3])
elif ndims == 3:
Expand All @@ -114,7 +113,7 @@ def from_4D_image(image, ndims):
return image


def _dynamic_from_4D_image(image, original_rank):
def _dynamic_from_4d_image(image, original_rank):
shape = tf.shape(image)
# 4D image <= [N, H, W, C] or [N, C, H, W]
# 3D image <= [1, H, W, C] or [1, C, H, W]
Expand Down Expand Up @@ -183,7 +182,7 @@ def transform(
transforms, name="transforms", dtype=tf.dtypes.float32)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
images = to_4D_image(image_or_images)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)

if output_shape is None:
Expand Down Expand Up @@ -217,7 +216,7 @@ def transform(
fill_mode=fill_mode.upper(),
fill_value=fill_value,
)
return from_4D_image(output, original_ndims)
return from_4d_image(output, original_ndims)


def angles_to_projective_transforms(
Expand Down Expand Up @@ -271,7 +270,7 @@ def angles_to_projective_transforms(
)


def rotate(
def rotate_img(
images: TensorLike,
angles: TensorLike,
interpolation: str = "nearest",
Expand All @@ -286,7 +285,7 @@ def rotate(
`(num_images, num_rows, num_columns, num_channels)`
(NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
`(num_rows, num_columns)` (HW).
angles: A scalar angle to rotate all images by, or (if `images` has rank 4)
angles: A scalar angle to rotate all images by (if `images` has rank 4)
a vector of length num_images, with an angle for each image in the
batch.
interpolation: Interpolation mode. Supported values: "nearest",
Expand Down Expand Up @@ -317,7 +316,7 @@ def rotate(
image_or_images = tf.convert_to_tensor(images)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
images = to_4D_image(image_or_images)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)

image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
Expand All @@ -329,7 +328,7 @@ def rotate(
fill_mode=fill_mode,
fill_value=fill_value,
)
return from_4D_image(output, original_ndims)
return from_4d_image(output, original_ndims)


def translations_to_projective_transforms(translations: TensorLike,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import tensorflow as tf

from .custom_tf_addons import rotate
from .custom_tf_addons import rotate_img
from .custom_tf_addons import transform
from .custom_tf_addons import translate

Expand Down Expand Up @@ -179,7 +179,7 @@ def rotate(image, degrees, replace):
# In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside
# of the function.
image = rotate(wrap(image), radians)
image = rotate_img(wrap(image), radians)
return unwrap(image, replace)


Expand Down

0 comments on commit d6dd2e8

Please sign in to comment.