diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index f9d7917..0cbb4ef 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -14,7 +14,6 @@ RandomRotation, RandomApply, Resize, - RandomResizedCrop, ) from torchvision.transforms.v2 import RandomHorizontalFlip @@ -54,9 +53,8 @@ def rand_weak_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), - RandomResizedCrop( - size, scale=(0.08, 1.0), ratio=(0.95, 1.05), antialias=True - ), + Resize(size, antialias=True), + CenterCrop(size), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), @@ -80,9 +78,8 @@ def rand_strong_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), - RandomResizedCrop( - size, scale=(0.08, 1.0), ratio=(0.9, 1.1), antialias=True - ), + Resize(size, antialias=True), + RandomCrop(size, pad_if_needed=False), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5),