From baac68a03ba8b2575bf32af2beb14cac7cccc793 Mon Sep 17 00:00:00 2001 From: Taylor Denouden Date: Tue, 30 Jan 2024 10:42:56 -0800 Subject: [PATCH] Use more standardized test time augmentation implementation (#99) * Add rotate and flip test time augmentation Accumulate logits, with maximum confidence output being the one used to predict the output * Add test time augmentation interfaces to CLI and lib functions * Update docs * Use more standard TTA implementation * Update lock file * Add kelp RGBI species stats * Use more standard TTA implementation * Update lock file --- kelp_o_matic/managers.py | 32 ++++++++++++++++++++------------ poetry.lock | 14 +++++++------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/kelp_o_matic/managers.py b/kelp_o_matic/managers.py index 5f19d98..dbb1b21 100644 --- a/kelp_o_matic/managers.py +++ b/kelp_o_matic/managers.py @@ -85,19 +85,27 @@ def __call__(self): crop = torch.nn.functional.pad( crop, (0, self.crop_size - w, 0, self.crop_size - h), value=0 ) - logits = self.model(crop.unsqueeze(0))[0] - if self.tta: - for k in range(1, 4): - aug_crop = torch.rot90(crop, k=k, dims=(1, 2)) - aug_logits = self.model(aug_crop.unsqueeze(0))[0] - unaug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2)) - logits = torch.maximum(logits, unaug_logits) - for d in [1, 2]: - aug_crop = torch.flip(crop, dims=(d,)) - aug_logits = self.model(aug_crop.unsqueeze(0))[0] - unaug_logits = torch.flip(aug_logits, dims=(d,)) - logits = torch.maximum(logits, unaug_logits) + all_logits = [] + for flip in [False, True]: + for k in range(4): + # Augment + aug_crop = torch.flip(crop, dims=(1,)) if flip else crop + aug_crop = torch.rot90(aug_crop, k=k, dims=(1, 2)) + # Classify + aug_logits = self.model(aug_crop.unsqueeze(0))[0] + # Un-augment + aug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2)) + logits = ( + torch.flip(aug_logits, dims=(1,)) + if flip + else aug_logits + ) + all_logits.append(logits) + logits = torch.stack(all_logits).mean(dim=0) + + else: + logits = self.model(crop.unsqueeze(0))[0] logits = self.kernel( logits, diff --git a/poetry.lock b/poetry.lock index 4285575..5e087f1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -634,13 +634,13 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.5.4" +version = "9.5.5" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.5.4-py3-none-any.whl", hash = "sha256:efd7cc8ae03296d728da9bd38f4db8b07ab61f9738a0cbd0dfaf2a15a50e7343"}, - {file = "mkdocs_material-9.5.4.tar.gz", hash = "sha256:3d196ee67fad16b2df1a458d650a8ac1890294eaae368d26cee71bc24ad41c40"}, + {file = "mkdocs_material-9.5.5-py3-none-any.whl", hash = "sha256:ac50b2431a79a3b160fdefbba37c9132485f1a69166aba115ad49fafdbbbc5df"}, + {file = "mkdocs_material-9.5.5.tar.gz", hash = "sha256:4480d9580faf42fed0123d0465502bfc1c0c239ecc9c4d66159cf0459ea1b4ae"}, ] [package.dependencies] @@ -658,7 +658,7 @@ requests = ">=2.26,<3.0" [package.extras] git = ["mkdocs-git-committers-plugin-2 (>=1.1,<2.0)", "mkdocs-git-revision-date-localized-plugin (>=1.2,<2.0)"] -imaging = ["cairosvg (>=2.6,<3.0)", "pillow (>=9.4,<10.0)"] +imaging = ["cairosvg (>=2.6,<3.0)", "pillow (>=10.2,<11.0)"] recommended = ["mkdocs-minify-plugin (>=0.7,<1.0)", "mkdocs-redirects (>=1.2,<2.0)", "mkdocs-rss-plugin (>=1.6,<2.0)"] [[package]] @@ -1095,13 +1095,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co [[package]] name = "pluggy" -version = "1.3.0" +version = "1.4.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" files = [ - {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, - {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, ] [package.extras]