Skip to content

Commit

Permalink
Use more standardized test time augmentation implementation (#99)
Browse files Browse the repository at this point in the history
* Add rotate and flip test time augmentation

Accumulate logits, with maximum confidence output being the one used to predict the output

* Add test time augmentation interfaces to CLI and lib functions

* Update docs

* Use more standard TTA implementation

* Update lock file

* Add kelp RGBI species stats

* Use more standard TTA implementation

* Update lock file
  • Loading branch information
tayden authored Jan 30, 2024
1 parent 3c33e75 commit baac68a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
32 changes: 20 additions & 12 deletions kelp_o_matic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,27 @@ def __call__(self):
crop = torch.nn.functional.pad(
crop, (0, self.crop_size - w, 0, self.crop_size - h), value=0
)
logits = self.model(crop.unsqueeze(0))[0]

if self.tta:
for k in range(1, 4):
aug_crop = torch.rot90(crop, k=k, dims=(1, 2))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2))
logits = torch.maximum(logits, unaug_logits)
for d in [1, 2]:
aug_crop = torch.flip(crop, dims=(d,))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.flip(aug_logits, dims=(d,))
logits = torch.maximum(logits, unaug_logits)
all_logits = []
for flip in [False, True]:
for k in range(4):
# Augment
aug_crop = torch.flip(crop, dims=(1,)) if flip else crop
aug_crop = torch.rot90(aug_crop, k=k, dims=(1, 2))
# Classify
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
# Un-augment
aug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2))
logits = (
torch.flip(aug_logits, dims=(1,))
if flip
else aug_logits
)
all_logits.append(logits)
logits = torch.stack(all_logits).mean(dim=0)

else:
logits = self.model(crop.unsqueeze(0))[0]

logits = self.kernel(
logits,
Expand Down
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit baac68a

Please sign in to comment.