From 0c11e6ae9cd947d0a0de77becf8544e55e55b8fc Mon Sep 17 00:00:00 2001 From: Evening Date: Wed, 5 Jun 2024 17:22:55 +0800 Subject: [PATCH 01/16] Make Partial saving function more generic --- src/frdc/models/efficientnetb1.py | 30 +------------- src/frdc/models/utils.py | 68 ++++++++++++++++++++----------- src/frdc/train/fixmatch_module.py | 10 ++++- src/frdc/train/mixmatch_module.py | 13 +++++- 4 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index 28276aa..5d5b4ef 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -10,7 +10,7 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import on_save_checkpoint, on_load_checkpoint +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -139,7 +139,6 @@ def update_ema(self): self.ema_updater.update(self.ema_lr) def forward(self, x: torch.Tensor): - """Forward pass.""" return self.fc(self.eff(x)) def configure_optimizers(self): @@ -147,21 +146,6 @@ def configure_optimizers(self): self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # TODO: MixMatch's saving is a bit complicated due to the dependency - # on the EMA model. This only saves the FC for both the - # main model and the EMA model. - # This may be the reason certain things break when loading - if checkpoint["hyper_parameters"]["frozen"]: - on_save_checkpoint( - self, - checkpoint, - saved_module_prefixes=("_ema_model.fc.", "fc."), - ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - on_load_checkpoint(self, checkpoint) - class EfficientNetB1FixMatchModule(FixMatchModule): MIN_SIZE = 255 @@ -209,21 +193,9 @@ def __init__( ) def forward(self, x: torch.Tensor): - """Forward pass.""" return self.fc(self.eff(x)) def configure_optimizers(self): return torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if checkpoint["hyper_parameters"]["frozen"]: - on_save_checkpoint( - self, - checkpoint, - saved_module_prefixes=("fc.",), - ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - on_load_checkpoint(self, checkpoint) diff --git a/src/frdc/models/utils.py b/src/frdc/models/utils.py index 2b75e1c..a3b94b3 100644 --- a/src/frdc/models/utils.py +++ b/src/frdc/models/utils.py @@ -1,15 +1,12 @@ -from typing import Dict, Any, Sequence +import logging +import warnings +from typing import Dict, Any, Sequence, Callable -def on_save_checkpoint( +def save_unfrozen( self, checkpoint: Dict[str, Any], - saved_module_prefixes: Sequence[str] = ("fc.",), - saved_module_suffixes: Sequence[str] = ( - "running_mean", - "running_var", - "num_batches_tracked", - ), + include_also: Callable[[str], bool] = lambda k: False, ) -> None: """Saving only the classifier if frozen. @@ -20,29 +17,54 @@ def on_save_checkpoint( This usually reduces the model size by 99.9%, so it's worth it. - By default, this will save the classifier and the BatchNorm running - statistics. + By default, this will save any parameter that requires grad + and the BatchNorm running statistics. Args: self: Not used, but kept for consistency with on_load_checkpoint. checkpoint: The checkpoint to save. 
- saved_module_prefixes: The prefixes of the modules to save. - saved_module_suffixes: The suffixes of the modules to save. + include_also: A function that returns whether to include a parameter, + on top of any parameter that requires grad and BatchNorm running + statistics. """ - if checkpoint["hyper_parameters"]["frozen"]: - # Keep only the classifier - checkpoint["state_dict"] = { - k: v - for k, v in checkpoint["state_dict"].items() - if ( - k.startswith(saved_module_prefixes) - or k.endswith(saved_module_suffixes) - ) - } + # Keep only the classifier + new_state_dict = {} + for k, v in checkpoint["state_dict"].items(): + # We keep 2 things, + # 1. The BatchNorm running statistics + # 2. Anything that requires grad -def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # BatchNorm running statistics should be kept + # for closer reconstruction of the model + is_bn_var = k.endswith( + ("running_mean", "running_var", "num_batches_tracked") + ) + try: + # We need to retrieve it from the original model + # as the state dict already freezes the model + is_required_grad = self.get_parameter(k).requires_grad + except AttributeError: + if not is_bn_var: + warnings.warn( + f"Unknown non-parameter key in state_dict. {k}." + f"This is an edge case where it's not a parameter nor " + f"BatchNorm running statistics. This will still be saved." + ) + is_required_grad = True + + # These are additional parameters to keep + is_include = include_also(k) + + if is_required_grad or is_bn_var or is_include: + logging.debug(f"Keeping {k}") + new_state_dict[k] = v + + checkpoint["state_dict"] = new_state_dict + + +def load_checkpoint_lenient(self, checkpoint: Dict[str, Any]) -> None: """Loading only the classifier if frozen Notes: diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 67e788d..0d8cd80 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Dict import torch import torch.nn.functional as F @@ -10,6 +10,7 @@ from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torchmetrics.functional import accuracy +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.utils import ( wandb_hist, preprocess, @@ -92,6 +93,7 @@ def training_step(self, batch, batch_idx): ℓ Loss: ℓ_lbl + ℓ_unl """ + def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -246,3 +248,9 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: y_encoder=self.y_encoder, x_unl=x_unl, ) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen(self, checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 300d25e..fb24b9b 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Dict import torch import torch.nn.functional as F @@ -11,6 +11,7 @@ from torch.nn.functional import one_hot from torchmetrics.functional import accuracy +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.utils import ( mix_up, sharpen, @@ -260,3 +261,13 @@ def 
on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: y_encoder=self.y_encoder, x_unl=x_unl, ) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen( + self, + checkpoint, + include_also=lambda k: k.startswith("_ema_model."), + ) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) From ce8f7bb1d687cd6762e071cea21c5e281ac1ce46 Mon Sep 17 00:00:00 2001 From: Evening Date: Wed, 5 Jun 2024 17:26:14 +0800 Subject: [PATCH 02/16] Reduce scope of MixMatch custom criteria saving --- src/frdc/train/mixmatch_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index fb24b9b..9617751 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -266,7 +266,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: save_unfrozen( self, checkpoint, - include_also=lambda k: k.startswith("_ema_model."), + include_also=lambda k: k.startswith("_ema_model.fc."), ) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: From cf1fae02b5ef92bd8c0d58625ab85c8f257233c3 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 6 Jun 2024 23:12:32 +0800 Subject: [PATCH 03/16] Fix flaky retrieval of task --- src/frdc/load/label_studio.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/frdc/load/label_studio.py b/src/frdc/load/label_studio.py index 5486e92..8fd5ec2 100644 --- a/src/frdc/load/label_studio.py +++ b/src/frdc/load/label_studio.py @@ -16,8 +16,6 @@ def get_bounds_and_labels(self) -> tuple[list[tuple[int, int]], list[str]]: bounds = [] labels = [] - # for ann_ix, ann in enumerate(self["annotations"]): - ann = self["annotations"][0] results = ann["result"] for r_ix, r in enumerate(results): @@ -60,24 +58,15 @@ def get_task( project_id: int = 1, ): proj = LABEL_STUDIO_CLIENT.get_project(project_id) - # Get the task that has the file name - filter = Filters.create( - Filters.AND, - [ - Filters.item( - # The GS path is in the image column, so we can just filter on that - Column.data("image"), - Operator.CONTAINS, - Type.String, - Path(file_name).as_posix(), - ) - ], - ) - tasks = proj.get_tasks(filter) - - if len(tasks) > 1: + task_ids = [ + task["id"] + for task in proj.get_tasks() + if file_name.as_posix() in task["storage_filename"] + ] + + if len(task_ids) > 1: warn(f"More than 1 task found for {file_name}, using the first one") - elif len(tasks) == 0: + elif len(task_ids) == 0: raise ValueError(f"No task found for {file_name}") - return Task(tasks[0]) + return Task(proj.get_task(task_ids[0])) From 02a4d57e794ada0bb347290fa97ae0e52ee13c52 Mon Sep 17 00:00:00 2001 From: Evening Date: Fri, 7 Jun 2024 22:19:43 +0800 Subject: [PATCH 04/16] Migrate x as preprocessing step in dataset --- src/frdc/load/dataset.py | 36 ++++++++--- src/frdc/load/preset.py | 57 ++++++++++++++---- src/frdc/models/efficientnetb1.py | 6 -- src/frdc/train/fixmatch_module.py | 4 -- src/frdc/train/mixmatch_module.py | 3 - src/frdc/train/utils.py | 59 ++----------------- .../chestnut_dec_may/train_fixmatch.py | 3 - tests/model_tests/utils.py | 8 --- 8 files changed, 76 insertions(+), 100 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index e0a0d9c..9ab6cb6 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from PIL import Image +from 
sklearn.preprocessing import StandardScaler from torch.utils.data import Dataset, ConcatDataset from frdc.conf import ( @@ -71,8 +72,9 @@ def __init__( site: str, date: str, version: str | None, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = False, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -95,6 +97,7 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. + transform_scale: Prepends a scaling transform to the transform. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -102,6 +105,7 @@ def __init__( to. polycrop: Whether to further crop the segments via its polygon bounds. The cropped area will be padded with np.nan. + polycrop_value: The value to pad the cropped area with. """ self.site = site self.date = date @@ -125,17 +129,33 @@ def __init__( self.transform = transform self.target_transform = target_transform + if transform_scale: + self.x_scaler = StandardScaler() + self.x_scaler.fit( + np.concatenate( + [ + # Segments: [H x W x C] -> [H*W, C] + # Reshaping is necessary for StandardScaler + segm.reshape(-1, segm.shape[-1]) + for segm in self.ar_segments + ] + ) + ) + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + else: + self.x_scaler = None + def __len__(self): return len(self.ar_segments) def __getitem__(self, idx): return ( - self.transform(self.ar_segments[idx]) - if self.transform - else self.ar_segments[idx], - self.target_transform(self.targets[idx]) - if self.target_transform - else self.targets[idx], + self.transform(self.ar_segments[idx]), + self.target_transform(self.targets[idx]), ) @property diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index da0aedc..4d09c88 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -47,15 +47,26 @@ class FRDCDatasetPartial: def __call__( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Alias for labelled().""" + """Alias for labelled(). + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. 
+ """ return self.labelled( transform, + transform_scale, target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, @@ -64,19 +75,30 @@ def __call__( def labelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Returns the Labelled Dataset.""" + """Returns the Labelled Dataset. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return FRDCDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -84,8 +106,9 @@ def labelled( def unlabelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -96,13 +119,22 @@ def unlabelled( This simply masks away the labels during __getitem__. The same behaviour can be achieved by setting __class__ to FRDCUnlabelledDataset, but this is a more convenient way to do so. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. """ return FRDCUnlabelledDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -167,5 +199,4 @@ class FRDCDatasetPreset: Resize((resize, resize)), ] ), - target_transform=None, ) diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index 5d5b4ef..d793249 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -83,7 +83,6 @@ def __init__( in_channels: int, n_classes: int, lr: float, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, @@ -95,7 +94,6 @@ def __init__( in_channels: The number of input channels. n_classes: The number of classes. lr: The learning rate. - x_scaler: The X input StandardScaler. y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. 
@@ -109,7 +107,6 @@ def __init__( super().__init__( n_classes=n_classes, - x_scaler=x_scaler, y_encoder=y_encoder, sharpen_temp=0.5, mix_beta_alpha=0.75, @@ -157,7 +154,6 @@ def __init__( in_channels: int, n_classes: int, lr: float, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, @@ -168,7 +164,6 @@ def __init__( in_channels: The number of input channels. n_classes: The number of classes. lr: The learning rate. - x_scaler: The X input StandardScaler. y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -181,7 +176,6 @@ def __init__( super().__init__( n_classes=n_classes, - x_scaler=x_scaler, y_encoder=y_encoder, ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 0d8cd80..93e8c9a 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -21,7 +21,6 @@ class FixMatchModule(LightningModule): def __init__( self, *, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, n_classes: int = 10, unl_conf_threshold: float = 0.95, @@ -41,7 +40,6 @@ def __init__( Args: n_classes: The number of classes in the dataset. - x_scaler: The StandardScaler to use for the data. y_encoder: The OrdinalEncoder to use for the labels. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. @@ -49,7 +47,6 @@ def __init__( super().__init__() - self.x_scaler = x_scaler self.y_encoder = y_encoder self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold @@ -244,7 +241,6 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: return preprocess( x_lbl=x_lbl, y_lbl=y_lbl, - x_scaler=self.x_scaler, y_encoder=self.y_encoder, x_unl=x_unl, ) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 9617751..cf6442a 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -24,7 +24,6 @@ class MixMatchModule(LightningModule): def __init__( self, *, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, n_classes: int = 10, sharpen_temp: float = 0.5, @@ -52,7 +51,6 @@ def __init__( super().__init__() - self.x_scaler = x_scaler self.y_encoder = y_encoder self.n_classes = n_classes self.sharpen_temp = sharpen_temp @@ -257,7 +255,6 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: return preprocess( x_lbl=x_lbl, y_lbl=y_lbl, - x_scaler=self.x_scaler, y_encoder=self.y_encoder, x_unl=x_unl, ) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index 6e85fe7..bb1b2e4 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -55,7 +55,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: def preprocess( x_lbl: torch.Tensor, y_lbl: torch.Tensor, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, x_unl: list[torch.Tensor] = None, nan_mask: bool = True, @@ -72,7 +71,6 @@ def preprocess( Args: x_lbl: The data to preprocess. y_lbl: The labels to preprocess. - x_scaler: The StandardScaler to use. y_encoder: The OrdinalEncoder to use. x_unl: The unlabelled data to preprocess. nan_mask: Whether to remove nan values from the batch. 
@@ -83,13 +81,8 @@ def preprocess( x_unl = [] if x_unl is None else x_unl - x_lbl_trans = x_standard_scale(x_scaler, x_lbl) - y_trans = y_encode(y_encoder, y_lbl) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: x_standard_scale(x_scaler, x), - type_atom=torch.Tensor, - type_list=list, + y_trans = torch.from_numpy( + y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] ) # Remove nan values from the batch @@ -98,10 +91,10 @@ def preprocess( nan = ( ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() ) - x_lbl_trans = x_lbl_trans[nan] + x_lbl_trans = x_lbl[nan] x_lbl_trans = torch.nan_to_num(x_lbl_trans) x_unl_trans = fn_recursive( - x_unl_trans, + x_unl, fn=lambda x: torch.nan_to_num(x[nan]), type_atom=torch.Tensor, type_list=list, @@ -111,50 +104,6 @@ def preprocess( return (x_lbl_trans, y_trans.long()), x_unl_trans -def x_standard_scale( - x_scaler: StandardScaler, x: torch.Tensor -) -> torch.Tensor: - """Standard scales the data - - Notes: - This is a wrapper around the StandardScaler to handle PyTorch tensors. - - Args: - x_scaler: The StandardScaler to use. - x: The data to standard scale, of shape (B, C, H, W). - """ - # Standard Scaler only accepts (n_samples, n_features), - # so we need to do some fancy reshaping. - # Note that moving dimensions then reshaping is different from just - # reshaping! - - # Move Channel to the last dimension then transform - # B x C x H x W -> B x H x W x C - b, c, h, w = x.shape - x_ss = x_scaler.transform(x.permute(0, 2, 3, 1).reshape(-1, c)) - - # Move Channel back to the second dimension - # B x H x W x C -> B x C x H x W - return torch.nan_to_num( - torch.from_numpy(x_ss.reshape(b, h, w, c)).permute(0, 3, 1, 2).float() - ) - - -def y_encode(y_encoder: OrdinalEncoder, y: torch.Tensor) -> torch.Tensor: - """Encodes the labels - - Notes: - This is a wrapper around the OrdinalEncoder to handle PyTorch tensors. - - Args: - y_encoder: The OrdinalEncoder to use. - y: The labels to encode. 
- """ - return torch.from_numpy( - y_encoder.transform(np.array(y).reshape(-1, 1))[..., 0] - ) - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b1df238..7975883 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -27,7 +27,6 @@ FRDCDatasetStaticEval, n_weak_strong_aug, get_y_encoder, - get_x_scaler, weak_aug, ) @@ -96,13 +95,11 @@ def main( ) oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], n_classes=len(oe.categories_[0]), lr=lr, - x_scaler=ss, y_encoder=oe, frozen=True, ) diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index ac5b54a..e4816b3 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -132,11 +132,3 @@ def get_y_encoder(targets): ) oe.fit(np.array(targets).reshape(-1, 1)) return oe - - -def get_x_scaler(segments): - ss = StandardScaler() - ss.fit( - np.concatenate([segm.reshape(-1, segm.shape[-1]) for segm in segments]) - ) - return ss From 8497c3e72307c1827056a0963d0abbb42f2433ba Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 11:23:35 +0800 Subject: [PATCH 05/16] Fix unexpected additional x scaler arg --- tests/integration_tests/test_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 97676dc..9db094c 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -48,7 +48,6 @@ def test_manual_segmentation_pipeline(model_fn, ds): in_channels=ds.ar.shape[-1], n_classes=n_classes, lr=1e-3, - x_scaler=ss, y_encoder=oe, frozen=True, ) From ab9de7ca2e5f8601b32936ff868494a15a91b50e Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 12:40:21 +0800 Subject: [PATCH 06/16] Refactor common modules to FRDCModule --- src/frdc/models/efficientnetb1.py | 28 ++-- src/frdc/train/fixmatch_module.py | 81 ++++------ src/frdc/train/frdc_module.py | 143 ++++++++++++++++++ src/frdc/train/mixmatch_module.py | 73 ++++----- src/frdc/train/utils.py | 52 ------- tests/integration_tests/test_pipeline.py | 18 +-- .../chestnut_dec_may/train_fixmatch.py | 7 +- .../chestnut_dec_may/train_mixmatch.py | 11 +- 8 files changed, 212 insertions(+), 201 deletions(-) create mode 100644 src/frdc/train/frdc_module.py diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index d793249..ae87d8d 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -1,8 +1,7 @@ from copy import deepcopy -from typing import Dict, Any +from typing import Sequence import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import ( EfficientNet, @@ -10,7 +9,6 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,9 +79,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, frozen: bool = True, @@ -92,9 +89,8 @@ 
def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -106,15 +102,14 @@ def __init__( self.weight_decay = weight_decay super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, + out_targets=out_targets, sharpen_temp=0.5, mix_beta_alpha=0.75, ) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) @@ -152,9 +147,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, ): @@ -162,9 +156,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -174,15 +167,12 @@ def __init__( self.lr = lr self.weight_decay = weight_decay - super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, - ) + super().__init__(out_targets=out_targets) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 93e8c9a..f29316c 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,28 +1,23 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient -from frdc.train.utils import ( - wandb_hist, - preprocess, -) +from frdc.train.frdc_module import FRDCModule +from frdc.train.utils import wandb_hist -class FixMatchModule(LightningModule): +class FixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], unl_conf_threshold: float = 0.95, ): """PyTorch Lightning Module for MixMatch @@ -39,16 +34,13 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. - y_encoder: The OrdinalEncoder to use for the labels. + out_targets: The output targets for the model. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold self.save_hyperparameters() @@ -60,7 +52,11 @@ def __init__( def forward(self, x): ... 
- def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): """A single training step for a batch Notes: @@ -91,7 +87,6 @@ def training_step(self, batch, batch_idx): Loss: ℓ_lbl + ℓ_unl """ - def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -172,7 +167,11 @@ def training_step(self, batch, batch_idx): } ) - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) @@ -194,7 +193,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ) -> STEP_OUTPUT: # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch y_pred = self(x) @@ -207,7 +210,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ): (x, y), _x_unls = batch y_pred = self(x) y_true_str = self.y_encoder.inverse_transform( @@ -217,36 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any: y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1) ) return y_true_str, y_pred_str - - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - save_unfrozen(self, checkpoint) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py new file mode 100644 index 0000000..22582e9 --- /dev/null +++ b/src/frdc/train/frdc_module.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +from lightning import LightningModule +from sklearn.preprocessing import OrdinalEncoder + +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.utils.utils import fn_recursive + + +class FRDCModule(LightningModule): + def __init__( + self, + *, + out_targets: Sequence[str], + nan_mask_missing_y_labels: bool = True, + ): + """Base Lightning Module for MixMatch + + Notes: + This is the base class for MixMatch and FixMatch. + This implements the Y-Encoder logic so that all modules can + encode and decode the tree string labels. + + Generally the hierarchy is: + Module + -> Module + -> FRDCModule + + E.g. 
+ EfficientNetB1MixMatchModule + -> MixMatchModule + -> FRDCModule + + WideResNetFixMatchModule + -> FixMatchModule + -> FRDCModule + + Args: + out_targets: The output targets for the model. + nan_mask_missing_y_labels: Whether to mask away x values that + have missing y labels. This happens when the y label is not + present in the OrdinalEncoder's categories, which happens + during non-training steps. E.g. A new unseen tree is inferred. + """ + + super().__init__() + + self.y_encoder: OrdinalEncoder = OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=np.nan, + ) + self.y_encoder.fit(np.array(out_targets).reshape(-1, 1)) + self.nan_mask_missing_y_labels = nan_mask_missing_y_labels + self.save_hyperparameters() + + @property + def n_classes(self): + return len(self.y_encoder.categories_[0]) + + @torch.no_grad() + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """This method is called before any data transfer to the device. + + Notes: + This method wraps OrdinalEncoder to convert labels from str to int + before transferring the data to the device. + + Note that this step must happen before the transfer as tensors + don't support str types. + + PyTorch Lightning may complain about this being on the Module + instead of the DataModule. However, this is intentional as we + want to export the model alongside the transformations. + """ + + if self.training: + (x_lbl, y_lbl), x_unl = batch + else: + x_lbl, y_lbl = batch + x_unl = [] + + y_trans = torch.from_numpy( + self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] + ) + + # Remove nan values from the batch + # Ordinal Encoders can return a np.nan if the value is not in the + # categories. We will remove that from the batch. + nan = ( + ~torch.isnan(y_trans) # Keeps all non-nan values + if self.nan_mask_missing_y_labels + else torch.ones_like(y_trans).bool() # Keeps all values + ) + + x_lbl_trans = torch.nan_to_num(x_lbl[nan]) + + # This function applies nan_to_num to all tensors in the list, + # regardless of how deeply nested they are. + x_unl_trans = fn_recursive( + x_unl, + fn=lambda x: torch.nan_to_num(x[nan]), + type_atom=torch.Tensor, + type_list=list, + ) + y_trans = y_trans[nan].long() + + return (x_lbl_trans, y_trans), x_unl_trans + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen(self, checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) + + # The following methods are to enforce the batch schema typing. + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: + ... 
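A minimal illustrative sketch of the unknown-label handling that the new
FRDCModule.on_before_batch_transfer builds on: the OrdinalEncoder maps any
label outside out_targets to NaN, which is then masked away before the batch
is moved to the device. The species strings below are hypothetical
placeholders, not values taken from the dataset.

# Illustrative only: label names are made up for this sketch.
import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder

y_encoder = OrdinalEncoder(
    handle_unknown="use_encoded_value",
    unknown_value=np.nan,
)
y_encoder.fit(np.array(["Campnosperma", "Falcataria"]).reshape(-1, 1))

y_lbl = ["Falcataria", "UnseenSpecies", "Campnosperma"]
y = torch.from_numpy(
    y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0]
)
keep = ~torch.isnan(y)   # False for the unseen label
print(y[keep].long())    # tensor([1, 0]) -- the unseen sample is dropped

This mirrors the nan-masking path in on_before_batch_transfer, where the same
mask is also applied to the labelled and unlabelled x tensors via fn_recursive.
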
diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index cf6442a..3b85785 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,31 +1,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.models.utils import save_unfrozen +from frdc.train.frdc_module import FRDCModule from frdc.train.utils import ( mix_up, sharpen, wandb_hist, - preprocess, ) -class MixMatchModule(LightningModule): +class MixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], sharpen_temp: float = 0.5, mix_beta_alpha: float = 0.75, ): @@ -43,16 +40,14 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. + out_targets: The output targets for the model. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() @@ -112,7 +107,11 @@ def progress(self): self.global_step / self.trainer.num_training_batches ) / self.trainer.max_epochs - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x_lbl, y_lbl), x_unls = batch self.log("train/x_lbl_mean", x_lbl.mean()) @@ -189,7 +188,11 @@ def training_step(self, batch, batch_idx): def on_after_backward(self) -> None: self.update_ema() - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) y_pred = self.ema_model(x) @@ -209,7 +212,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) @@ -221,7 +228,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: (x, y), _x_unls = batch y_pred = self.ema_model(x) y_true_str = self.y_encoder.inverse_transform( @@ -232,39 +242,10 @@ def predict_step(self, batch, *args, **kwargs) -> Any: ) return y_true_str, y_pred_str - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. 
- - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """This override the original method to save the EMAs as well.""" save_unfrozen( self, checkpoint, include_also=lambda k: k.startswith("_ema_model.fc."), ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index bb1b2e4..b42ce91 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -52,58 +52,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp -def preprocess( - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - y_encoder: OrdinalEncoder, - x_unl: list[torch.Tensor] = None, - nan_mask: bool = True, -) -> tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]]: - """Preprocesses the data - - Notes: - The reason why x and y's preprocessing is coupled is due to the NaN - elimination step. The NaN elimination step is due to unseen labels by y - - fn_recursive is to recursively apply some function to a nested list. - This happens due to unlabelled being a list of tensors. - - Args: - x_lbl: The data to preprocess. - y_lbl: The labels to preprocess. - y_encoder: The OrdinalEncoder to use. - x_unl: The unlabelled data to preprocess. - nan_mask: Whether to remove nan values from the batch. - - Returns: - The preprocessed data and labels. - """ - - x_unl = [] if x_unl is None else x_unl - - y_trans = torch.from_numpy( - y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] - ) - - # Remove nan values from the batch - # Ordinal Encoders can return a np.nan if the value is not in the - # categories. We will remove that from the batch. 
- nan = ( - ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() - ) - x_lbl_trans = x_lbl[nan] - x_lbl_trans = torch.nan_to_num(x_lbl_trans) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: torch.nan_to_num(x[nan]), - type_atom=torch.Tensor, - type_list=list, - ) - y_trans = y_trans[nan] - - return (x_lbl_trans, y_trans.long()), x_unl_trans - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 9db094c..3de3e6b 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -2,10 +2,8 @@ from pathlib import Path import lightning as pl -import numpy as np import pytest import torch -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.models.efficientnetb1 import ( EfficientNetB1MixMatchModule, @@ -33,22 +31,10 @@ def test_manual_segmentation_pipeline(model_fn, ds): val_ds=ds, batch_size=BATCH_SIZE, ) - - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, - ) - oe.fit(np.array(ds.targets).reshape(-1, 1)) - n_classes = len(oe.categories_[0]) - - ss = StandardScaler() - ss.fit(ds.ar.reshape(-1, ds.ar.shape[-1])) - m = model_fn( in_channels=ds.ar.shape[-1], - n_classes=n_classes, lr=1e-3, - y_encoder=oe, + out_targets=ds.targets, frozen=True, ) @@ -92,4 +78,4 @@ def test_manual_segmentation_pipeline(model_fn, ds): # E.g. achieved via hash comparison. # This is because BatchNorm usually keeps running statistics # and reloading the model will reset them. - # We don't necessarily need to + # We don't necessarily need to check for this. diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index 7975883..b5804de 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -94,13 +94,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - y_encoder=oe, frozen=True, ) @@ -121,7 +118,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 1bf839f..7b85260 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -25,8 +25,6 @@ FRDCDatasetStaticEval, n_strong_aug, strong_aug, - get_y_encoder, - get_x_scaler, ) @@ -85,15 +83,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -114,7 +107,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") From 
13b15930a9b1dabc27ee2f235a02e876ad8bf5b8 Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 17:12:26 +0800 Subject: [PATCH 07/16] Allow StandardScaler to override default scaler --- src/frdc/load/dataset.py | 15 ++++++++++++--- src/frdc/load/preset.py | 19 +++++++++++++------ .../chestnut_dec_may/train_fixmatch.py | 9 ++++++--- .../chestnut_dec_may/train_mixmatch.py | 10 ++++++++-- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index 9ab6cb6..4666283 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -73,7 +73,7 @@ def __init__( date: str, version: str | None, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = False, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -97,7 +97,9 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. - transform_scale: Prepends a scaling transform to the transform. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -129,7 +131,7 @@ def __init__( self.transform = transform self.target_transform = target_transform - if transform_scale: + if transform_scale is True: self.x_scaler = StandardScaler() self.x_scaler.fit( np.concatenate( @@ -146,6 +148,13 @@ def __init__( x.shape ) ) + elif isinstance(transform_scale, StandardScaler): + self.x_scaler = transform_scale + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) else: self.x_scaler = None diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index 4d09c88..6fa54df 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,6 +6,7 @@ import numpy as np import torch +from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -48,7 +49,7 @@ class FRDCDatasetPartial: def __call__( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -58,7 +59,9 @@ def __call__( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -76,7 +79,7 @@ def __call__( def labelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -86,7 +89,9 @@ def labelled( Args: transform: The transform to apply to the data. 
- transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -107,7 +112,7 @@ def labelled( def unlabelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -122,7 +127,9 @@ def unlabelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b5804de..c83bf43 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -26,7 +26,6 @@ val_preprocess, FRDCDatasetStaticEval, n_weak_strong_aug, - get_y_encoder, weak_aug, ) @@ -57,9 +56,12 @@ def main( im_size = 255 train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor) + transform=n_weak_strong_aug(im_size, unlabelled_factor), + ) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -115,6 +117,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 7b85260..0aab51f 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -44,11 +44,16 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size)) + train_lab_ds = ds.chestnut_20201218( + transform=strong_aug(im_size), + ) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_strong_aug(im_size, 2) ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, + ) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -104,6 +109,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) From 236c019295c7674e82e2ff755c3341230457ef6f Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 11 Jun 2024 12:08:44 +0800 Subject: [PATCH 08/16] Improve augmentation naming scheme --- .../chestnut_dec_may/train_fixmatch.py | 32 ++++----- .../chestnut_dec_may/train_mixmatch.py | 14 ++-- tests/model_tests/utils.py | 72 +++++++++---------- 3 files changed, 57 
insertions(+), 61 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index c83bf43..86cf0ab 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -23,10 +23,10 @@ from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( - val_preprocess, + const_weak_aug, FRDCDatasetStaticEval, - n_weak_strong_aug, - weak_aug, + n_rand_weak_strong_aug, + rand_weak_aug, ) @@ -37,15 +37,15 @@ def main( - batch_size=32, - epochs=10, - train_iters=25, - unlabelled_factor=2, - lr=1e-3, - accelerator="gpu", - wandb_active: bool = True, - wandb_name="chestnut_dec_may", - wandb_project="frdc", + batch_size=32, + epochs=10, + train_iters=25, + unlabelled_factor=2, + lr=1e-3, + accelerator="gpu", + wandb_active: bool = True, + wandb_name="chestnut_dec_may", + wandb_project="frdc", ): if not wandb_active: import os @@ -54,12 +54,12 @@ def main( # Prepare the dataset im_size = 255 - train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) + train_lab_ds = ds.chestnut_20201218(transform=rand_weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor), + transform=n_rand_weak_strong_aug(im_size, unlabelled_factor), ) val_ds = ds.chestnut_20210510_43m( - transform=val_preprocess(im_size), + transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ) @@ -116,7 +116,7 @@ def main( "chestnut_nature_park", "20210510", "90deg43m85pct255deg", - transform=val_preprocess(im_size), + transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ), model=m, diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 0aab51f..e696943 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -21,10 +21,10 @@ from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( - val_preprocess, + const_weak_aug, FRDCDatasetStaticEval, - n_strong_aug, - strong_aug, + n_rand_strong_aug, + rand_strong_aug, ) @@ -45,13 +45,13 @@ def main( # Prepare the dataset im_size = 299 train_lab_ds = ds.chestnut_20201218( - transform=strong_aug(im_size), + transform=rand_strong_aug(im_size), ) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_strong_aug(im_size, 2) + transform=n_rand_strong_aug(im_size, 2) ) val_ds = ds.chestnut_20210510_43m( - transform=val_preprocess(im_size), + transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ) @@ -108,7 +108,7 @@ def main( "chestnut_nature_park", "20210510", "90deg43m85pct255deg", - transform=val_preprocess(im_size), + transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ), model=m, diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index e4816b3..da0aa29 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -65,70 +65,66 @@ def __getitem__(self, idx): return x_, y -def val_preprocess(size: int): - return lambda x: Compose( - [ - ToImage(), - ToDtype(torch.float32, scale=True), - Resize(size, antialias=True), - CenterCrop(size), - ] - )(x) +def n_times(f, n: int): + return lambda x: [f(x) for _ in range(n)] -def n_weak_aug(size, n_aug: int = 2): - return 
lambda x: ( - [weak_aug(size)(x) for _ in range(n_aug)] if n_aug > 0 else None - ) +def n_rand_weak_aug(size, n_aug: int = 2): + return n_times(rand_weak_aug(size), n_aug) -def n_strong_aug(size, n_aug: int = 2): - return lambda x: ( - [strong_aug(size)(x) for _ in range(n_aug)] if n_aug > 0 else None - ) +def n_rand_strong_aug(size, n_aug: int = 2): + return n_times(rand_strong_aug(size), n_aug) -def n_weak_strong_aug(size, n_aug: int = 2): +def n_rand_weak_strong_aug(size, n_aug: int = 2): def f(x): - x_weak = n_weak_aug(size, n_aug)(x) - x_strong = n_strong_aug(size, n_aug)(x) - return list(zip(*[x_weak, x_strong])) if n_aug > 0 else None + # x_weak = [weak_0, weak_1, ..., weak_n] + x_weak = n_rand_weak_aug(size, n_aug)(x) + # x_strong = [strong_0, strong_1, ..., strong_n] + x_strong = n_rand_strong_aug(size, n_aug)(x) + # x_paired = [(weak_0, strong_0), (weak_1, strong_1), + # ..., (weak_n, strong_n)] + x_paired = list(zip(*[x_weak, x_strong])) + return x_paired return f -def weak_aug(size: int): - return lambda x: Compose( +def rand_weak_aug(size: int): + return Compose( [ ToImage(), ToDtype(torch.float32, scale=True), - Resize(size, antialias=True), - CenterCrop(size), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), + Resize(size, antialias=True), + CenterCrop(size), ] - )(x) + ) -def strong_aug(size: int): - return lambda x: Compose( +def const_weak_aug(size: int): + return Compose( [ ToImage(), ToDtype(torch.float32, scale=True), Resize(size, antialias=True), - RandomCrop(size, pad_if_needed=False), # Strong + CenterCrop(size), + ] + ) + + +def rand_strong_aug(size: int): + return Compose( + [ + ToImage(), + ToDtype(torch.float32, scale=True), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), + Resize(size, antialias=True), + RandomCrop(size, pad_if_needed=False), # Strong ] - )(x) - - -def get_y_encoder(targets): - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, ) - oe.fit(np.array(targets).reshape(-1, 1)) - return oe From b6f68ea84fec88f6970eef1547bbe1fb8704fd3b Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 11 Jun 2024 14:39:55 +0800 Subject: [PATCH 09/16] Migrate Const Rotated Dataset to prefix variant --- src/frdc/load/dataset.py | 39 ++++++++++++++++ src/frdc/load/preset.py | 44 ++++++++++++++++++- .../chestnut_dec_may/train_fixmatch.py | 30 ++++++------- tests/model_tests/utils.py | 43 ------------------ 4 files changed, 96 insertions(+), 60 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index 4666283..4076e1e 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -10,7 +10,9 @@ import pandas as pd from PIL import Image from sklearn.preprocessing import StandardScaler +from torch import rot90 from torch.utils.data import Dataset, ConcatDataset +from torchvision.transforms.v2.functional import hflip from frdc.conf import ( BAND_CONFIG, @@ -334,3 +336,40 @@ def __getitem__(self, item): if self.transform else self.ar_segments[item] ) + + +class FRDCConstRotatedDataset(FRDCDataset): + def __len__(self): + """Assume that the dataset is 8x larger than it actually is. + + There are 8 possible orientations for each image. + 1. As-is + 2, 3, 4. Rotated 90, 180, 270 degrees + 5. Horizontally flipped + 6, 7, 8. 
Horizontally flipped and rotated 90, 180, 270 degrees + """ + return super().__len__() * 8 + + def __getitem__(self, idx): + """Alter the getitem method to implement the logic above.""" + x, y = super().__getitem__(int(idx // 8)) + assert x.ndim == 3, "x must be a 3D tensor" + x_ = None + if idx % 8 == 0: + x_ = x + elif idx % 8 == 1: + x_ = rot90(x, 1, (1, 2)) + elif idx % 8 == 2: + x_ = rot90(x, 2, (1, 2)) + elif idx % 8 == 3: + x_ = rot90(x, 3, (1, 2)) + elif idx % 8 == 4: + x_ = hflip(x) + elif idx % 8 == 5: + x_ = hflip(rot90(x, 1, (1, 2))) + elif idx % 8 == 6: + x_ = hflip(rot90(x, 2, (1, 2))) + elif idx % 8 == 7: + x_ = hflip(rot90(x, 3, (1, 2))) + + return x_, y diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index 6fa54df..518dd11 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -14,7 +14,11 @@ Resize, ) -from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset +from frdc.load.dataset import ( + FRDCDataset, + FRDCUnlabelledDataset, + FRDCConstRotatedDataset, +) logger = logging.getLogger(__name__) @@ -147,6 +151,44 @@ def unlabelled( polycrop_value=polycrop_value, ) + def const_rotated( + self, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, + use_legacy_bounds: bool = False, + polycrop: bool = False, + polycrop_value: Any = np.nan, + ): + """Returns the Unlabelled Dataset. + + Notes: + This simply masks away the labels during __getitem__. + The same behaviour can be achieved by setting __class__ to + FRDCUnlabelledDataset, but this is a more convenient way to do so. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. 
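+
+        Examples:
+            A minimal evaluation sketch. This is hedged: ``const_weak_aug``
+            and the fitted ``x_scaler`` are assumed to come from the calling
+            script (as in tests/model_tests), not from this module::
+
+                from frdc.load.preset import FRDCDatasetPreset as ds
+
+                val_ds = ds.chestnut_20210510_43m.const_rotated(
+                    transform=const_weak_aug(255),
+                    transform_scale=train_lab_ds.x_scaler,
+                )
+                # Each segment is served 8 times: as-is, rotated 90/180/270
+                # degrees, and the horizontal flips of those four, so
+                # len(val_ds) is 8x the underlying dataset length.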
+ """ + return FRDCConstRotatedDataset( + self.site, + self.date, + self.version, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, + use_legacy_bounds=use_legacy_bounds, + polycrop=polycrop, + polycrop_value=polycrop_value, + ) + @dataclass class FRDCDatasetPreset: diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index 86cf0ab..4f01798 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -24,12 +24,13 @@ from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( const_weak_aug, - FRDCDatasetStaticEval, n_rand_weak_strong_aug, rand_weak_aug, ) +# %% + # Uncomment this to run the W&B monitoring locally # import os # @@ -37,15 +38,15 @@ def main( - batch_size=32, - epochs=10, - train_iters=25, - unlabelled_factor=2, - lr=1e-3, - accelerator="gpu", - wandb_active: bool = True, - wandb_name="chestnut_dec_may", - wandb_project="frdc", + batch_size=32, + epochs=10, + train_iters=25, + unlabelled_factor=2, + lr=1e-3, + accelerator="gpu", + wandb_active: bool = True, + wandb_name="chestnut_dec_may", + wandb_project="frdc", ): if not wandb_active: import os @@ -112,10 +113,7 @@ def main( ) y_true, y_pred = predict( - ds=FRDCDatasetStaticEval( - "chestnut_nature_park", - "20210510", - "90deg43m85pct255deg", + ds=ds.chestnut_20210510_43m.const_rotated( transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ), @@ -133,8 +131,8 @@ def main( if __name__ == "__main__": BATCH_SIZE = 32 - EPOCHS = 10 - TRAIN_ITERS = 25 + EPOCHS = 2 + TRAIN_ITERS = 2 LR = 3e-3 torch.set_float32_matmul_precision("high") diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index da0aa29..a596964 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -2,10 +2,7 @@ from pathlib import Path -import numpy as np import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler -from torch import rot90 from torchvision.transforms import RandomVerticalFlip from torchvision.transforms.v2 import ( Compose, @@ -19,52 +16,12 @@ Resize, ) from torchvision.transforms.v2 import RandomHorizontalFlip -from torchvision.transforms.v2.functional import hflip - -from frdc.load.dataset import FRDCDataset THIS_DIR = Path(__file__).parent BANDS = ["NB", "NG", "NR", "RE", "NIR"] -class FRDCDatasetStaticEval(FRDCDataset): - def __len__(self): - """Assume that the dataset is 8x larger than it actually is. - - There are 8 possible orientations for each image. - 1. As-is - 2, 3, 4. Rotated 90, 180, 270 degrees - 5. Horizontally flipped - 6, 7, 8. 
Horizontally flipped and rotated 90, 180, 270 degrees - """ - return super().__len__() * 8 - - def __getitem__(self, idx): - """Alter the getitem method to implement the logic above.""" - x, y = super().__getitem__(int(idx // 8)) - assert x.ndim == 3, "x must be a 3D tensor" - x_ = None - if idx % 8 == 0: - x_ = x - elif idx % 8 == 1: - x_ = rot90(x, 1, (1, 2)) - elif idx % 8 == 2: - x_ = rot90(x, 2, (1, 2)) - elif idx % 8 == 3: - x_ = rot90(x, 3, (1, 2)) - elif idx % 8 == 4: - x_ = hflip(x) - elif idx % 8 == 5: - x_ = hflip(rot90(x, 1, (1, 2))) - elif idx % 8 == 6: - x_ = hflip(rot90(x, 2, (1, 2))) - elif idx % 8 == 7: - x_ = hflip(rot90(x, 3, (1, 2))) - - return x_, y - - def n_times(f, n: int): return lambda x: [f(x) for _ in range(n)] From a51c2d8bae33dabb096cc5c37326d6d137343d58 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 11 Jun 2024 14:41:12 +0800 Subject: [PATCH 10/16] Update mixmatch's training to use new variant --- tests/model_tests/chestnut_dec_may/train_mixmatch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index e696943..edae259 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -16,13 +16,13 @@ ) from lightning.pytorch.loggers import WandbLogger +from frdc.load.dataset import FRDCConstRotatedDataset from frdc.load.preset import FRDCDatasetPreset as ds from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( const_weak_aug, - FRDCDatasetStaticEval, n_rand_strong_aug, rand_strong_aug, ) @@ -104,10 +104,7 @@ def main( ) y_true, y_pred = predict( - ds=FRDCDatasetStaticEval( - "chestnut_nature_park", - "20210510", - "90deg43m85pct255deg", + ds=ds.chestnut_20210510_43m.const_rotated( transform=const_weak_aug(im_size), transform_scale=train_lab_ds.x_scaler, ), From 4fd7875df8383e022e82d90286e11aa68e9ce076 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 11 Jun 2024 14:45:47 +0800 Subject: [PATCH 11/16] Revert epoch and iteration changes for fixmatch --- tests/model_tests/chestnut_dec_may/train_fixmatch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index 4f01798..ac6b66a 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -131,8 +131,8 @@ def main( if __name__ == "__main__": BATCH_SIZE = 32 - EPOCHS = 2 - TRAIN_ITERS = 2 + EPOCHS = 10 + TRAIN_ITERS = 25 LR = 3e-3 torch.set_float32_matmul_precision("high") From aa5292e8d62d024236c3b234b7b8352efe6e96e2 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 11 Jun 2024 16:08:11 +0800 Subject: [PATCH 12/16] Fix issue that slows down preprocessing --- tests/model_tests/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index a596964..f9d7917 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -14,6 +14,7 @@ RandomRotation, RandomApply, Resize, + RandomResizedCrop, ) from torchvision.transforms.v2 import RandomHorizontalFlip @@ -53,11 +54,12 @@ def rand_weak_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), + RandomResizedCrop( + 
size, scale=(0.08, 1.0), ratio=(0.95, 1.05), antialias=True + ), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), - Resize(size, antialias=True), - CenterCrop(size), ] ) @@ -78,10 +80,11 @@ def rand_strong_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), + RandomResizedCrop( + size, scale=(0.08, 1.0), ratio=(0.9, 1.1), antialias=True + ), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), - Resize(size, antialias=True), - RandomCrop(size, pad_if_needed=False), # Strong ] ) From 754b27307fcb32bbcfcc0dd0777459509ac67352 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 13 Jun 2024 17:24:09 +0800 Subject: [PATCH 13/16] Improve naming scheme for label_studio module --- src/frdc/load/label_studio.py | 51 ++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/src/frdc/load/label_studio.py b/src/frdc/load/label_studio.py index 8fd5ec2..435cc40 100644 --- a/src/frdc/load/label_studio.py +++ b/src/frdc/load/label_studio.py @@ -16,39 +16,58 @@ def get_bounds_and_labels(self) -> tuple[list[tuple[int, int]], list[str]]: bounds = [] labels = [] - ann = self["annotations"][0] - results = ann["result"] - for r_ix, r in enumerate(results): - r: dict + # Each annotation is an entire image labelled by a single person. + # By selecting the 0th annotation, we are usually selecting the main + # annotation. + annotation = self["annotations"][0] + + # There are some metadata in `annotation`, but we just want the results + results = annotation["result"] + + for bbox_ix, bbox in enumerate(results): + # 'id' = {str} 'jr4EXAKAV8' + # 'type' = {str} 'polygonlabels' + # 'value' = {dict: 3} { + # 'closed': True, + # 'points': [[x0, y0], [x1, y1], ... [xn, yn]], + # 'polygonlabels': ['label'] + # } + # 'origin' = {str} 'manual' + # 'to_name' = {str} 'image' + # 'from_name' = {str} 'label' + # 'image_rotation' = {int} 0 + # 'original_width' = {int} 450 + # 'original_height' = {int} 600 + bbox: dict # See Issue FRML-78: Somehow some labels are actually just metadata - if r["from_name"] != "label": + if bbox["from_name"] != "label": continue # We flatten the value dict into the result dict - v = r.pop("value") - r = {**r, **v} + v = bbox.pop("value") + bbox = {**bbox, **v} # Points are in percentage, we need to convert them to pixels - r["points"] = [ + bbox["points"] = [ ( - int(x * r["original_width"] / 100), - int(y * r["original_height"] / 100), + int(x * bbox["original_width"] / 100), + int(y * bbox["original_height"] / 100), ) - for x, y in r["points"] + for x, y in bbox["points"] ] # Only take the first label as this is not a multi-label task - r["label"] = r.pop("polygonlabels")[0] - if not r["closed"]: + bbox["label"] = bbox.pop("polygonlabels")[0] + if not bbox["closed"]: logger.warning( - f"Label for {r['label']} @ {r['points']} not closed. " + f"Label for {bbox['label']} @ {bbox['points']} not closed. 
" f"Skipping" ) continue - bounds.append(r["points"]) - labels.append(r["label"]) + bounds.append(bbox["points"]) + labels.append(bbox["label"]) return bounds, labels From 367f6bf5af460c1e53a85b0be0a2c562e5167ad4 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 13 Jun 2024 17:24:32 +0800 Subject: [PATCH 14/16] Add spin up dependencies and script for replica --- .../label-studio-replica/default_config.xml | 121 ++++++++++++++++++ .../label-studio-replica/docker-compose.yml | 78 +++++++++++ .../initialize_replica.py | 99 ++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 src/label-studio/label-studio-replica/default_config.xml create mode 100644 src/label-studio/label-studio-replica/docker-compose.yml create mode 100644 src/label-studio/label-studio-replica/initialize_replica.py diff --git a/src/label-studio/label-studio-replica/default_config.xml b/src/label-studio/label-studio-replica/default_config.xml new file mode 100644 index 0000000..051211f --- /dev/null +++ b/src/label-studio/label-studio-replica/default_config.xml @@ -0,0 +1,121 @@ + +
[The XML markup of default_config.xml did not survive extraction; only the visible label text is recoverable. The labelling interface is titled "Replica FRDC Server", carries the notice "This is a replica server. All changes will NOT be reflected onto the Machine Learning Pipeline.", and defines the sections "Select Species", "Select Quality", "Submitted By (Team):", "UserID (Submit):", "Checked By (Team):", and "UserID (Check):".]
\ No newline at end of file diff --git a/src/label-studio/label-studio-replica/docker-compose.yml b/src/label-studio/label-studio-replica/docker-compose.yml new file mode 100644 index 0000000..80bc06c --- /dev/null +++ b/src/label-studio/label-studio-replica/docker-compose.yml @@ -0,0 +1,78 @@ +version: "3.9" +services: + nginx: + build: . + image: heartexlabs/label-studio:latest + restart: unless-stopped + ports: + - "8082:8085" + - "8083:8086" + depends_on: + - app + environment: + - LABEL_STUDIO_HOST=${LABEL_STUDIO_HOST:-} + # Optional: Specify SSL termination certificate & key + # Just drop your cert.pem and cert.key into folder 'deploy/nginx/certs' + # - NGINX_SSL_CERT=/certs/cert.pem + # - NGINX_SSL_CERT_KEY=/certs/cert.key + volumes: + - ./mydata:/label-studio/data:rw + - ./deploy/nginx/certs:/certs:ro + # Optional: Override nginx default conf + # - ./deploy/my.conf:/etc/nginx/nginx.conf + command: nginx + networks: + - label-studio-dev + + app: + stdin_open: true + tty: true + build: . + image: heartexlabs/label-studio:latest + restart: unless-stopped + expose: + - "8000" + depends_on: + - db-dev + environment: + - DJANGO_DB=default + - POSTGRE_NAME=postgres + - POSTGRE_USER=postgres + - POSTGRE_PASSWORD= + - POSTGRE_PORT=5432 + - POSTGRE_HOST=db-dev + - LABEL_STUDIO_HOST=${LABEL_STUDIO_HOST:-} + - JSON_LOG=1 + # - LOG_LEVEL=DEBUG + volumes: + - ./mydata:/label-studio/data:rw + command: label-studio-uwsgi + networks: + - label-studio-dev + + db-dev: + image: postgres:11.5 + hostname: db-dev + restart: unless-stopped + # Optional: Enable TLS on PostgreSQL + # Just drop your server.crt and server.key into folder 'deploy/pgsql/certs' + # NOTE: Both files must have permissions u=rw (0600) or less + # command: > + # -c ssl=on + # -c ssl_cert_file=/var/lib/postgresql/certs/server.crt + # -c ssl_key_file=/var/lib/postgresql/certs/server.key + ports: + - "5435:5432" + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + volumes: + - ${POSTGRES_DATA_DIR:-./postgres-data}:/var/lib/postgresql/data + - ${POSTGRES_DATA_DIR:-./postgres-backups}:/var/lib/postgresql/backups + - ./deploy/pgsql/certs:/var/lib/postgresql/certs:ro + networks: + - label-studio-dev + +networks: + label-studio-dev: + name: label-studio-dev + driver: bridge diff --git a/src/label-studio/label-studio-replica/initialize_replica.py b/src/label-studio/label-studio-replica/initialize_replica.py new file mode 100644 index 0000000..71fc30b --- /dev/null +++ b/src/label-studio/label-studio-replica/initialize_replica.py @@ -0,0 +1,99 @@ +import os +import time +from pathlib import Path + +import label_studio_sdk + +THIS_DIR = Path(__file__).parent + +# This is your API token. I put mine here, which is OK only if you're in a +# development environment. Otherwise, do not. +dev_api_key = os.getenv("REPLICA_LABEL_STUDIO_API_KEY") +prd_api_key = os.getenv("LABEL_STUDIO_API_KEY") +dev_url = "http://localhost:8082" +prd_url = "http://localhost:8080" + +# We can initialize the sdk using this following. +# The client is like the middleman between you as a programmer, and the +# Label Studio (LS) server. +dev_client = label_studio_sdk.Client(url=dev_url, api_key=dev_api_key) +prd_client = label_studio_sdk.Client(url=prd_url, api_key=prd_api_key) + +# This is the labelling interface configuration. 
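+# Note (assumption): the prediction-copying step further below reuses the
+# production results as-is; those results reference the control/object names
+# "label" -> "image" and the production label values, so this replica config
+# is expected to keep the same names for the copied predictions to render.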
+# We can save it somewhere as an XML file then import it too +dev_config = (THIS_DIR / "default_config.xml").read_text() + +# %% +print("Creating Development Project...") +# Creates the project, note to set the config here +dev_proj = dev_client.create_project( + title="FRDC Replica", + description="This is the replica project of FRDC. It's ok to break this.", + label_config=dev_config, + color="#FF0025", +) +# %% +print("Adding Import Source...") +# This links to our GCS as an import source +dev_storage = dev_proj.connect_google_import_storage( + bucket="frdc-ds", + regex_filter=".*.jpg", + google_application_credentials=( + THIS_DIR / "frmodel-943e4feae446.json" + ).read_text(), + presign=False, + title="Source", +) +time.sleep(5) +# %% +print("Syncing Storage...") +# Then, we sync it so that all the images appear as annotation targets +dev_proj.sync_storage( + storage_type=dev_storage["type"], + storage_id=dev_storage["id"], +) +time.sleep(5) +# %% +print("Retrieving Tasks...") +prd_proj = prd_client.get_project(id=1) +prd_tasks = prd_proj.get_tasks() +dev_tasks = dev_proj.get_tasks() +# %% +# This step copies over the annotations from the production to the development +# This creates it as a "prediction" +print("Copying Annotations...") +for prd_task in prd_tasks: + # For each prod task, we find the corresponding (image) file name + prd_fn = prd_task["storage_filename"] + + # Then, we find the corresponding task in the development project + dev_tasks_matched = [ + t for t in dev_tasks if t["storage_filename"] == prd_fn + ] + + # Do some error handling + if len(dev_tasks_matched) == 0: + print(f"File not found in dev: {prd_fn}") + continue + if len(dev_tasks_matched) > 1: + print(f"Too many matches found in dev: {prd_fn}") + continue + + # Get the first match + dev_task = dev_tasks_matched[0] + + # Only get annotations by evening + prd_ann = [ + ann + for ann in prd_task["annotations"] + if "dev_evening" in ann["created_username"] + ][0] + + # Create the prediction using the result from production + dev_proj.create_prediction( + task_id=dev_task["id"], + result=prd_ann["result"], + model_version="API Testing Prediction", + ) + +print("Done!") From fd6ca408e1c8c5500dd279629d26a6219150fc45 Mon Sep 17 00:00:00 2001 From: Evening Date: Wed, 19 Jun 2024 13:54:24 +0800 Subject: [PATCH 15/16] Revert Augmentation Change This caused significant degradation in performance. 
Reverting it for original behavior unless proved to be better --- tests/model_tests/utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index f9d7917..0cbb4ef 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -14,7 +14,6 @@ RandomRotation, RandomApply, Resize, - RandomResizedCrop, ) from torchvision.transforms.v2 import RandomHorizontalFlip @@ -54,9 +53,8 @@ def rand_weak_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), - RandomResizedCrop( - size, scale=(0.08, 1.0), ratio=(0.95, 1.05), antialias=True - ), + Resize(size, antialias=True), + CenterCrop(size), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), @@ -80,9 +78,8 @@ def rand_strong_aug(size: int): [ ToImage(), ToDtype(torch.float32, scale=True), - RandomResizedCrop( - size, scale=(0.08, 1.0), ratio=(0.9, 1.1), antialias=True - ), + Resize(size, antialias=True), + RandomCrop(size, pad_if_needed=False), RandomHorizontalFlip(), RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), From b208a5f9a8a015ac333479553be8ef3d6011fafc Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Wed, 19 Jun 2024 16:59:18 +0800 Subject: [PATCH 16/16] Force rerun --- .github/workflows/model-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/model-tests.yml b/.github/workflows/model-tests.yml index a04f378..ef88ebd 100644 --- a/.github/workflows/model-tests.yml +++ b/.github/workflows/model-tests.yml @@ -85,8 +85,8 @@ jobs: working-directory: ${{ github.workspace }}/tests run: | git config --global --add safe.directory /__w/FRDC-ML/FRDC-ML - python3 -m model_tests.chestnut_dec_may.train_mixmatch python3 -m model_tests.chestnut_dec_may.train_fixmatch + python3 -m model_tests.chestnut_dec_may.train_mixmatch - name: Comment results via CML run: |