Skip to content

Commit

Permalink
Merge pull request #38 from ZerenLong/fix_bug_in_augmentation
Browse files Browse the repository at this point in the history
[Bug] Fix bug in augmentations.py
  • Loading branch information
hurjunhwa authored May 28, 2021
2 parents 9103839 + 8bec0ea commit 4d7f6aa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def apply_random_transforms_to_params(self,

theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid).float() * thetas
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas

# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
Expand Down Expand Up @@ -796,7 +796,7 @@ def apply_random_transforms_to_params(self,

theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid).float() * thetas
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas

# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
Expand Down Expand Up @@ -1075,7 +1075,7 @@ def apply_random_transforms_to_params(self,

theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
theta_try = apply_transform_to_params(theta0, theta_transform)
thetas = invalid.float() * theta_try + (1 - invalid).float() * thetas
thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas

# compute new invalid ones
invalid = self.find_invalid(width=width, height=height, thetas=thetas)
Expand Down

0 comments on commit 4d7f6aa

Please sign in to comment.