Skip to content

Commit

Permalink
Add "sample" argument back for transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
jsilter committed Mar 11, 2024
1 parent d0da448 commit 6b62545
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
16 changes: 9 additions & 7 deletions sybil/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def __init__(self, augmentations):
super(ComposeAug, self).__init__()
self.augmentations = augmentations

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
for transformer in self.augmentations:
input_dict = transformer(input_dict)
input_dict = transformer(input_dict, sample)

return input_dict

Expand All @@ -97,7 +97,7 @@ def __init__(self):
self.transform = ToTensorV2()
self.name = "totensor"

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
input_dict["input"] = torch.from_numpy(input_dict["input"]).float()
if input_dict.get("mask", None) is not None:
input_dict["mask"] = torch.from_numpy(input_dict["mask"]).float()
Expand All @@ -117,7 +117,7 @@ def __init__(self, args, kwargs):
self.set_cachable(width, height)
self.transform = A.Resize(height, width)

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
out = self.transform(
image=input_dict["input"], mask=input_dict.get("mask", None)
)
Expand All @@ -140,7 +140,9 @@ def __init__(self, args, kwargs):
self.max_angle = int(kwargs["deg"])
self.transform = A.Rotate(limit=self.max_angle, p=0.5)

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
if sample and "seed" in sample:
self.set_seed(sample["seed"])
out = self.transform(
image=input_dict["input"], mask=input_dict.get("mask", None)
)
Expand Down Expand Up @@ -169,7 +171,7 @@ def __init__(self, args, kwargs):
"png",
]

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
img = input_dict["input"]
if len(img.size()) == 2:
img = img.unsqueeze(0)
Expand All @@ -193,7 +195,7 @@ def __init__(self, args, kwargs):
assert len(kwargs) == 0
self.args = args

def __call__(self, input_dict):
def __call__(self, input_dict, sample=None):
img = input_dict["input"]
mask = input_dict.get("mask", None)
if mask is not None:
Expand Down
14 changes: 8 additions & 6 deletions sybil/loaders/abstract_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def split_augmentations_by_cache(augmentations):


def apply_augmentations_and_cache(
loaded_input, img_path, augmentations, cache, base_key=""
loaded_input, sample, img_path, augmentations, cache, base_key=""
):
"""
Loads the loaded input by its absolute path and apply the augmentations one
Expand All @@ -69,7 +69,7 @@ def apply_augmentations_and_cache(
all_prev_cachable = True
key = base_key
for ind, trans in enumerate(augmentations):
loaded_input = trans(loaded_input)
loaded_input = trans(loaded_input, sample)
if not all_prev_cachable or not trans.cachable():
all_prev_cachable = False
else:
Expand Down Expand Up @@ -153,17 +153,17 @@ def load_input(self, path):
def cached_extension(self):
pass

def configure_path(self, path):
def configure_path(self, path, sample=None):
return path

def get_image(self, path):
def get_image(self, path, sample=None):
"""
Returns a transformed image by its absolute path.
If cache is used - transformed image will be loaded if available,
and saved to cache if not.
"""
input_dict = {}
input_path = self.configure_path(path)
input_path = self.configure_path(path, sample)

if input_path == self.pad_token:
return self.load_input(input_path)
Expand All @@ -172,7 +172,7 @@ def get_image(self, path):
input_dict = self.load_input(input_path)
# hidden loaders typically do not use augmentation
if self.apply_augmentations:
input_dict = self.composed_all_augmentations(input_dict)
input_dict = self.composed_all_augmentations(input_dict, sample)
return input_dict

if self.args.use_annotations:
Expand All @@ -192,6 +192,7 @@ def get_image(self, path):
if self.apply_augmentations:
input_dict = apply_augmentations_and_cache(
input_dict,
sample,
input_path,
post_augmentations,
self.cache,
Expand All @@ -210,6 +211,7 @@ def get_image(self, path):
if self.apply_augmentations:
input_dict = apply_augmentations_and_cache(
input_dict,
sample,
input_path,
all_augmentations,
self.cache,
Expand Down

0 comments on commit 6b62545

Please sign in to comment.