From a0b587aed0ccecb794a46e2ba99713c56ed69f93 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 22:04:59 +0530 Subject: [PATCH] resolved pyline and changed the pylint version to current version of main --- .../imagenet_jax/custom_tf_addons.py | 20 ++++++++++++------- setup.cfg | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 79aef6791..3d6939218 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -241,12 +241,15 @@ def angles_to_projective_transforms( with tf.name_scope(name or "angles_to_projective_transforms"): angle_or_angles = tf.convert_to_tensor( angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles else: - raise ValueError("angles should have rank 0 or 1.") + angles = angle_or_angles + cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) x_offset = ((image_width - 1) - @@ -352,12 +355,15 @@ def translations_to_projective_transforms(translations: TensorLike, if translation_or_translations.get_shape().ndims is None: raise TypeError( "translation_or_translations rank must be statically known") - elif len(translation_or_translations.get_shape()) == 1: + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] - elif len(translation_or_translations.get_shape()) == 2: - translations = translation_or_translations else: - raise TypeError("Translations should have rank 1 or 2.") + translations = translation_or_translations + num_translations = tf.shape(translations)[0] # The translation matrix looks like: # [[1 0 -dx] diff --git a/setup.cfg b/setup.cfg index a7c224407..7977267bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,7 +78,7 @@ full_dev = # Dependencies for developing the package dev = isort==5.13.2 - pylint==3.3.1 + pylint==2.16.1 pytest==8.3.3 yapf==0.32.0 pre-commit==4.0.1