diff --git a/.github/workflows/model-tests.yml b/.github/workflows/model-tests.yml index a04f3780..ef88ebdd 100644 --- a/.github/workflows/model-tests.yml +++ b/.github/workflows/model-tests.yml @@ -85,8 +85,8 @@ jobs: working-directory: ${{ github.workspace }}/tests run: | git config --global --add safe.directory /__w/FRDC-ML/FRDC-ML - python3 -m model_tests.chestnut_dec_may.train_mixmatch python3 -m model_tests.chestnut_dec_may.train_fixmatch + python3 -m model_tests.chestnut_dec_may.train_mixmatch - name: Comment results via CML run: | diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index e0a0d9c0..4076e1e9 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -9,7 +9,10 @@ import numpy as np import pandas as pd from PIL import Image +from sklearn.preprocessing import StandardScaler +from torch import rot90 from torch.utils.data import Dataset, ConcatDataset +from torchvision.transforms.v2.functional import hflip from frdc.conf import ( BAND_CONFIG, @@ -71,8 +74,9 @@ def __init__( site: str, date: str, version: str | None, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -95,6 +99,9 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -102,6 +109,7 @@ def __init__( to. polycrop: Whether to further crop the segments via its polygon bounds. The cropped area will be padded with np.nan. + polycrop_value: The value to pad the cropped area with. 
""" self.site = site self.date = date @@ -125,17 +133,40 @@ def __init__( self.transform = transform self.target_transform = target_transform + if transform_scale is True: + self.x_scaler = StandardScaler() + self.x_scaler.fit( + np.concatenate( + [ + # Segments: [H x W x C] -> [H*W, C] + # Reshaping is necessary for StandardScaler + segm.reshape(-1, segm.shape[-1]) + for segm in self.ar_segments + ] + ) + ) + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + elif isinstance(transform_scale, StandardScaler): + self.x_scaler = transform_scale + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + else: + self.x_scaler = None + def __len__(self): return len(self.ar_segments) def __getitem__(self, idx): return ( - self.transform(self.ar_segments[idx]) - if self.transform - else self.ar_segments[idx], - self.target_transform(self.targets[idx]) - if self.target_transform - else self.targets[idx], + self.transform(self.ar_segments[idx]), + self.target_transform(self.targets[idx]), ) @property @@ -305,3 +336,40 @@ def __getitem__(self, item): if self.transform else self.ar_segments[item] ) + + +class FRDCConstRotatedDataset(FRDCDataset): + def __len__(self): + """Assume that the dataset is 8x larger than it actually is. + + There are 8 possible orientations for each image. + 1. As-is + 2, 3, 4. Rotated 90, 180, 270 degrees + 5. Horizontally flipped + 6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees + """ + return super().__len__() * 8 + + def __getitem__(self, idx): + """Alter the getitem method to implement the logic above.""" + x, y = super().__getitem__(int(idx // 8)) + assert x.ndim == 3, "x must be a 3D tensor" + x_ = None + if idx % 8 == 0: + x_ = x + elif idx % 8 == 1: + x_ = rot90(x, 1, (1, 2)) + elif idx % 8 == 2: + x_ = rot90(x, 2, (1, 2)) + elif idx % 8 == 3: + x_ = rot90(x, 3, (1, 2)) + elif idx % 8 == 4: + x_ = hflip(x) + elif idx % 8 == 5: + x_ = hflip(rot90(x, 1, (1, 2))) + elif idx % 8 == 6: + x_ = hflip(rot90(x, 2, (1, 2))) + elif idx % 8 == 7: + x_ = hflip(rot90(x, 3, (1, 2))) + + return x_, y diff --git a/src/frdc/load/label_studio.py b/src/frdc/load/label_studio.py index 5486e920..435cc40f 100644 --- a/src/frdc/load/label_studio.py +++ b/src/frdc/load/label_studio.py @@ -16,41 +16,58 @@ def get_bounds_and_labels(self) -> tuple[list[tuple[int, int]], list[str]]: bounds = [] labels = [] - # for ann_ix, ann in enumerate(self["annotations"]): - - ann = self["annotations"][0] - results = ann["result"] - for r_ix, r in enumerate(results): - r: dict + # Each annotation is an entire image labelled by a single person. + # By selecting the 0th annotation, we are usually selecting the main + # annotation. + annotation = self["annotations"][0] + + # There are some metadata in `annotation`, but we just want the results + results = annotation["result"] + + for bbox_ix, bbox in enumerate(results): + # 'id' = {str} 'jr4EXAKAV8' + # 'type' = {str} 'polygonlabels' + # 'value' = {dict: 3} { + # 'closed': True, + # 'points': [[x0, y0], [x1, y1], ... 
[xn, yn]], + # 'polygonlabels': ['label'] + # } + # 'origin' = {str} 'manual' + # 'to_name' = {str} 'image' + # 'from_name' = {str} 'label' + # 'image_rotation' = {int} 0 + # 'original_width' = {int} 450 + # 'original_height' = {int} 600 + bbox: dict # See Issue FRML-78: Somehow some labels are actually just metadata - if r["from_name"] != "label": + if bbox["from_name"] != "label": continue # We flatten the value dict into the result dict - v = r.pop("value") - r = {**r, **v} + v = bbox.pop("value") + bbox = {**bbox, **v} # Points are in percentage, we need to convert them to pixels - r["points"] = [ + bbox["points"] = [ ( - int(x * r["original_width"] / 100), - int(y * r["original_height"] / 100), + int(x * bbox["original_width"] / 100), + int(y * bbox["original_height"] / 100), ) - for x, y in r["points"] + for x, y in bbox["points"] ] # Only take the first label as this is not a multi-label task - r["label"] = r.pop("polygonlabels")[0] - if not r["closed"]: + bbox["label"] = bbox.pop("polygonlabels")[0] + if not bbox["closed"]: logger.warning( - f"Label for {r['label']} @ {r['points']} not closed. " + f"Label for {bbox['label']} @ {bbox['points']} not closed. " f"Skipping" ) continue - bounds.append(r["points"]) - labels.append(r["label"]) + bounds.append(bbox["points"]) + labels.append(bbox["label"]) return bounds, labels @@ -60,24 +77,15 @@ def get_task( project_id: int = 1, ): proj = LABEL_STUDIO_CLIENT.get_project(project_id) - # Get the task that has the file name - filter = Filters.create( - Filters.AND, - [ - Filters.item( - # The GS path is in the image column, so we can just filter on that - Column.data("image"), - Operator.CONTAINS, - Type.String, - Path(file_name).as_posix(), - ) - ], - ) - tasks = proj.get_tasks(filter) - - if len(tasks) > 1: + task_ids = [ + task["id"] + for task in proj.get_tasks() + if file_name.as_posix() in task["storage_filename"] + ] + + if len(task_ids) > 1: warn(f"More than 1 task found for {file_name}, using the first one") - elif len(tasks) == 0: + elif len(task_ids) == 0: raise ValueError(f"No task found for {file_name}") - return Task(tasks[0]) + return Task(proj.get_task(task_ids[0])) diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index da0aedcb..518dd119 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,6 +6,7 @@ import numpy as np import torch +from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -13,7 +14,11 @@ Resize, ) -from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset +from frdc.load.dataset import ( + FRDCDataset, + FRDCUnlabelledDataset, + FRDCConstRotatedDataset, +) logger = logging.getLogger(__name__) @@ -47,15 +52,28 @@ class FRDCDatasetPartial: def __call__( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Alias for labelled().""" + """Alias for labelled(). + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. 
+ use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return self.labelled( transform, + transform_scale, target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -64,19 +82,32 @@ def labelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Returns the Labelled Dataset.""" + """Returns the Labelled Dataset. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return FRDCDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -84,8 +115,9 @@ def labelled( def unlabelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -96,13 +128,62 @@ def unlabelled( This simply masks away the labels during __getitem__. The same behaviour can be achieved by setting __class__ to FRDCUnlabelledDataset, but this is a more convenient way to do so. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. """ return FRDCUnlabelledDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, + use_legacy_bounds=use_legacy_bounds, + polycrop=polycrop, + polycrop_value=polycrop_value, + ) + + def const_rotated( + self, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, + use_legacy_bounds: bool = False, + polycrop: bool = False, + polycrop_value: Any = np.nan, + ): + """Returns the Constant Rotated Dataset. + + Notes: + This dataset yields each segment in all 8 fixed orientations + (the 4 right-angle rotations, with and without a horizontal + flip), so it is 8x the length of the labelled dataset. This + is useful for deterministic test-time augmentation during + evaluation. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data.
If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ + return FRDCConstRotatedDataset( + self.site, + self.date, + self.version, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -167,5 +248,4 @@ class FRDCDatasetPreset: Resize((resize, resize)), ] ), - target_transform=None, ) diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index 28276aa0..ae87d8d1 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -1,8 +1,7 @@ from copy import deepcopy -from typing import Dict, Any +from typing import Sequence import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import ( EfficientNet, @@ -10,7 +9,6 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import on_save_checkpoint, on_load_checkpoint from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,10 +79,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, frozen: bool = True, @@ -93,10 +89,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - x_scaler: The X input StandardScaler. - y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -108,16 +102,14 @@ def __init__( self.weight_decay = weight_decay super().__init__( - n_classes=n_classes, - x_scaler=x_scaler, - y_encoder=y_encoder, + out_targets=out_targets, sharpen_temp=0.5, mix_beta_alpha=0.75, ) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) @@ -139,7 +131,6 @@ def update_ema(self): self.ema_updater.update(self.ema_lr) def forward(self, x: torch.Tensor): - """Forward pass.""" return self.fc(self.eff(x)) def configure_optimizers(self): @@ -147,21 +138,6 @@ def configure_optimizers(self): self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # TODO: MixMatch's saving is a bit complicated due to the dependency - # on the EMA model. This only saves the FC for both the - # main model and the EMA model. 
- # This may be the reason certain things break when loading - if checkpoint["hyper_parameters"]["frozen"]: - on_save_checkpoint( - self, - checkpoint, - saved_module_prefixes=("_ema_model.fc.", "fc."), - ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - on_load_checkpoint(self, checkpoint) - class EfficientNetB1FixMatchModule(FixMatchModule): MIN_SIZE = 255 @@ -171,10 +147,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, ): @@ -182,10 +156,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - x_scaler: The X input StandardScaler. - y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -195,35 +167,19 @@ def __init__( self.lr = lr self.weight_decay = weight_decay - super().__init__( - n_classes=n_classes, - x_scaler=x_scaler, - y_encoder=y_encoder, - ) + super().__init__(out_targets=out_targets) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) def forward(self, x: torch.Tensor): - """Forward pass.""" return self.fc(self.eff(x)) def configure_optimizers(self): return torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if checkpoint["hyper_parameters"]["frozen"]: - on_save_checkpoint( - self, - checkpoint, - saved_module_prefixes=("fc.",), - ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - on_load_checkpoint(self, checkpoint) diff --git a/src/frdc/models/utils.py b/src/frdc/models/utils.py index 2b75e1ce..a3b94b33 100644 --- a/src/frdc/models/utils.py +++ b/src/frdc/models/utils.py @@ -1,15 +1,12 @@ -from typing import Dict, Any, Sequence +import logging +import warnings +from typing import Dict, Any, Sequence, Callable -def on_save_checkpoint( +def save_unfrozen( self, checkpoint: Dict[str, Any], - saved_module_prefixes: Sequence[str] = ("fc.",), - saved_module_suffixes: Sequence[str] = ( - "running_mean", - "running_var", - "num_batches_tracked", - ), + include_also: Callable[[str], bool] = lambda k: False, ) -> None: """Saving only the classifier if frozen. @@ -20,29 +17,54 @@ def on_save_checkpoint( This usually reduces the model size by 99.9%, so it's worth it. - By default, this will save the classifier and the BatchNorm running - statistics. + By default, this will save any parameter that requires grad + and the BatchNorm running statistics. Args: self: Not used, but kept for consistency with on_load_checkpoint. checkpoint: The checkpoint to save. - saved_module_prefixes: The prefixes of the modules to save. - saved_module_suffixes: The suffixes of the modules to save. + include_also: A function that returns whether to include a parameter, + on top of any parameter that requires grad and BatchNorm running + statistics. 
""" - if checkpoint["hyper_parameters"]["frozen"]: - # Keep only the classifier - checkpoint["state_dict"] = { - k: v - for k, v in checkpoint["state_dict"].items() - if ( - k.startswith(saved_module_prefixes) - or k.endswith(saved_module_suffixes) - ) - } + # Keep only the classifier + new_state_dict = {} + for k, v in checkpoint["state_dict"].items(): + # We keep 2 things, + # 1. The BatchNorm running statistics + # 2. Anything that requires grad -def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # BatchNorm running statistics should be kept + # for closer reconstruction of the model + is_bn_var = k.endswith( + ("running_mean", "running_var", "num_batches_tracked") + ) + try: + # We need to retrieve it from the original model + # as the state dict already freezes the model + is_required_grad = self.get_parameter(k).requires_grad + except AttributeError: + if not is_bn_var: + warnings.warn( + f"Unknown non-parameter key in state_dict. {k}." + f"This is an edge case where it's not a parameter nor " + f"BatchNorm running statistics. This will still be saved." + ) + is_required_grad = True + + # These are additional parameters to keep + is_include = include_also(k) + + if is_required_grad or is_bn_var or is_include: + logging.debug(f"Keeping {k}") + new_state_dict[k] = v + + checkpoint["state_dict"] = new_state_dict + + +def load_checkpoint_lenient(self, checkpoint: Dict[str, Any]) -> None: """Loading only the classifier if frozen Notes: diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 67e788d3..f29316c0 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,28 +1,23 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics.functional import accuracy -from frdc.train.utils import ( - wandb_hist, - preprocess, -) +from frdc.train.frdc_module import FRDCModule +from frdc.train.utils import wandb_hist -class FixMatchModule(LightningModule): +class FixMatchModule(FRDCModule): def __init__( self, *, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], unl_conf_threshold: float = 0.95, ): """PyTorch Lightning Module for MixMatch @@ -39,18 +34,13 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. - x_scaler: The StandardScaler to use for the data. - y_encoder: The OrdinalEncoder to use for the labels. + out_targets: The output targets for the model. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.x_scaler = x_scaler - self.y_encoder = y_encoder - self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold self.save_hyperparameters() @@ -62,7 +52,11 @@ def __init__( def forward(self, x): ... 
- def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): """A single training step for a batch Notes: @@ -92,7 +86,7 @@ def training_step(self, batch, batch_idx): ℓ Loss: ℓ_lbl + ℓ_unl """ - def training_step(self, batch, batch_idx): + (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -173,7 +167,11 @@ def training_step(self, batch, batch_idx): } ) - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) @@ -195,7 +193,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ) -> STEP_OUTPUT: # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch y_pred = self(x) @@ -208,7 +210,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ): (x, y), _x_unls = batch y_pred = self(x) y_true_str = self.y_encoder.inverse_transform( @@ -218,31 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any: y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1) ) return y_true_str, y_pred_str - - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - x_scaler=self.x_scaler, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py new file mode 100644 index 00000000..22582e92 --- /dev/null +++ b/src/frdc/train/frdc_module.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +from lightning import LightningModule +from sklearn.preprocessing import OrdinalEncoder + +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.utils.utils import fn_recursive + + +class FRDCModule(LightningModule): + def __init__( + self, + *, + out_targets: Sequence[str], + nan_mask_missing_y_labels: bool = True, + ): + """Base Lightning Module for FRDC models + + Notes: + This is the base class for MixMatch and FixMatch. + This implements the Y-Encoder logic so that all modules can + encode and decode the tree string labels. + + Generally the hierarchy is: + Module + -> Module + -> FRDCModule + + E.g. + EfficientNetB1MixMatchModule + -> MixMatchModule + -> FRDCModule + + WideResNetFixMatchModule + -> FixMatchModule + -> FRDCModule + + Args: + out_targets: The output targets for the model.
+ nan_mask_missing_y_labels: Whether to mask away x values that + have missing y labels. This happens when the y label is not + present in the OrdinalEncoder's categories, which happens + during non-training steps. E.g. A new unseen tree is inferred. + """ + + super().__init__() + + self.y_encoder: OrdinalEncoder = OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=np.nan, + ) + self.y_encoder.fit(np.array(out_targets).reshape(-1, 1)) + self.nan_mask_missing_y_labels = nan_mask_missing_y_labels + self.save_hyperparameters() + + @property + def n_classes(self): + return len(self.y_encoder.categories_[0]) + + @torch.no_grad() + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """This method is called before any data transfer to the device. + + Notes: + This method wraps OrdinalEncoder to convert labels from str to int + before transferring the data to the device. + + Note that this step must happen before the transfer as tensors + don't support str types. + + PyTorch Lightning may complain about this being on the Module + instead of the DataModule. However, this is intentional as we + want to export the model alongside the transformations. + """ + + if self.training: + (x_lbl, y_lbl), x_unl = batch + else: + x_lbl, y_lbl = batch + x_unl = [] + + y_trans = torch.from_numpy( + self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] + ) + + # Remove nan values from the batch + # Ordinal Encoders can return a np.nan if the value is not in the + # categories. We will remove that from the batch. + nan = ( + ~torch.isnan(y_trans) # Keeps all non-nan values + if self.nan_mask_missing_y_labels + else torch.ones_like(y_trans).bool() # Keeps all values + ) + + x_lbl_trans = torch.nan_to_num(x_lbl[nan]) + + # This function applies nan_to_num to all tensors in the list, + # regardless of how deeply nested they are. + x_unl_trans = fn_recursive( + x_unl, + fn=lambda x: torch.nan_to_num(x[nan]), + type_atom=torch.Tensor, + type_list=list, + ) + y_trans = y_trans[nan].long() + + return (x_lbl_trans, y_trans), x_unl_trans + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen(self, checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) + + # The following methods are to enforce the batch schema typing. + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: + ... 
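For reference, the snippet below sketches the label-encoding behaviour that the new FRDCModule relies on, in isolation: known classes are encoded to integers, while unseen classes become np.nan and are masked out, mirroring on_before_batch_transfer above. The tree names are made up for illustration.

import numpy as np
from sklearn.preprocessing import OrdinalEncoder

# Mirrors FRDCModule.__init__: unseen labels encode to np.nan instead
# of raising, so they can be masked out downstream.
enc = OrdinalEncoder(
    handle_unknown="use_encoded_value",
    unknown_value=np.nan,
)
enc.fit(np.array(["Mahogany", "Oak", "Teak"]).reshape(-1, 1))  # out_targets

y = np.array(["Teak", "UnseenTree", "Oak"]).reshape(-1, 1)
y_enc = enc.transform(y)[..., 0]  # array([ 2., nan,  1.])

# Mirrors the nan-masking in on_before_batch_transfer: drop samples
# whose label was not in out_targets, then cast to integer classes.
mask = ~np.isnan(y_enc)
print(y_enc[mask].astype(int))  # [2 1]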
diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 300d25e5..3b857851 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,31 +1,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Dict, Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot from torchmetrics.functional import accuracy +from frdc.models.utils import save_unfrozen +from frdc.train.frdc_module import FRDCModule from frdc.train.utils import ( mix_up, sharpen, wandb_hist, - preprocess, ) -class MixMatchModule(LightningModule): +class MixMatchModule(FRDCModule): def __init__( self, *, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], sharpen_temp: float = 0.5, mix_beta_alpha: float = 0.75, ): @@ -43,17 +40,14 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. + out_targets: The output targets for the model. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.x_scaler = x_scaler - self.y_encoder = y_encoder - self.n_classes = n_classes self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() @@ -113,7 +107,11 @@ def progress(self): self.global_step / self.trainer.num_training_batches ) / self.trainer.max_epochs - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x_lbl, y_lbl), x_unls = batch self.log("train/x_lbl_mean", x_lbl.mean()) @@ -190,7 +188,11 @@ def training_step(self, batch, batch_idx): def on_after_backward(self) -> None: self.update_ema() - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) y_pred = self.ema_model(x) @@ -210,7 +212,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) @@ -222,7 +228,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: (x, y), _x_unls = batch y_pred = self.ema_model(x) y_true_str = self.y_encoder.inverse_transform( @@ -233,30 +242,10 @@ def predict_step(self, batch, *args, **kwargs) -> Any: ) return y_true_str, y_pred_str - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. 
- - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - x_scaler=self.x_scaler, - y_encoder=self.y_encoder, - x_unl=x_unl, + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """This overrides the original method to save the EMAs as well.""" + save_unfrozen( + self, + checkpoint, + include_also=lambda k: k.startswith("_ema_model.fc."), ) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index 6e85fe7a..b42ce91e 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -52,109 +52,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp -def preprocess( - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - x_unl: list[torch.Tensor] = None, - nan_mask: bool = True, -) -> tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]]: - """Preprocesses the data - - Notes: - The reason why x and y's preprocessing is coupled is due to the NaN - elimination step. The NaN elimination step is due to unseen labels by y - - fn_recursive is to recursively apply some function to a nested list. - This happens due to unlabelled being a list of tensors. - - Args: - x_lbl: The data to preprocess. - y_lbl: The labels to preprocess. - x_scaler: The StandardScaler to use. - y_encoder: The OrdinalEncoder to use. - x_unl: The unlabelled data to preprocess. - nan_mask: Whether to remove nan values from the batch. - - Returns: - The preprocessed data and labels. - """ - - x_unl = [] if x_unl is None else x_unl - - x_lbl_trans = x_standard_scale(x_scaler, x_lbl) - y_trans = y_encode(y_encoder, y_lbl) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: x_standard_scale(x_scaler, x), - type_atom=torch.Tensor, - type_list=list, - ) - - # Remove nan values from the batch - # Ordinal Encoders can return a np.nan if the value is not in the - # categories. We will remove that from the batch. - nan = ( - ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() - ) - x_lbl_trans = x_lbl_trans[nan] - x_lbl_trans = torch.nan_to_num(x_lbl_trans) - x_unl_trans = fn_recursive( - x_unl_trans, - fn=lambda x: torch.nan_to_num(x[nan]), - type_atom=torch.Tensor, - type_list=list, - ) - y_trans = y_trans[nan] - - return (x_lbl_trans, y_trans.long()), x_unl_trans - - -def x_standard_scale( - x_scaler: StandardScaler, x: torch.Tensor -) -> torch.Tensor: - """Standard scales the data - - Notes: - This is a wrapper around the StandardScaler to handle PyTorch tensors. - - Args: - x_scaler: The StandardScaler to use. - x: The data to standard scale, of shape (B, C, H, W). - """ - # Standard Scaler only accepts (n_samples, n_features), - # so we need to do some fancy reshaping. - # Note that moving dimensions then reshaping is different from just - # reshaping!
- - # Move Channel to the last dimension then transform - # B x C x H x W -> B x H x W x C - b, c, h, w = x.shape - x_ss = x_scaler.transform(x.permute(0, 2, 3, 1).reshape(-1, c)) - - # Move Channel back to the second dimension - # B x H x W x C -> B x C x H x W - return torch.nan_to_num( - torch.from_numpy(x_ss.reshape(b, h, w, c)).permute(0, 3, 1, 2).float() - ) - - -def y_encode(y_encoder: OrdinalEncoder, y: torch.Tensor) -> torch.Tensor: - """Encodes the labels - - Notes: - This is a wrapper around the OrdinalEncoder to handle PyTorch tensors. - - Args: - y_encoder: The OrdinalEncoder to use. - y: The labels to encode. - """ - return torch.from_numpy( - y_encoder.transform(np.array(y).reshape(-1, 1))[..., 0] - ) - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/src/label-studio/label-studio-replica/default_config.xml b/src/label-studio/label-studio-replica/default_config.xml new file mode 100644 index 00000000..051211f4 --- /dev/null +++ b/src/label-studio/label-studio-replica/default_config.xml @@ -0,0 +1,121 @@ + +
+  <Header value="Replica FRDC Server"/>
+  <View>
+    This is a replica server. All changes will NOT be reflected onto the Machine Learning Pipeline.
+  </View>
+  <Image name="image" value="$image"/>
+
+  <Header value="Select Species"/>
+  <PolygonLabels name="label" toName="image">
+    <!-- species Label choices ... -->
+  </PolygonLabels>
+
+  <Header value="Select Quality"/>
+  <!-- quality Choices ... -->
+
+  <Header value="Submitted By (Team):"/>
+  <!-- team Choices ... -->
+
+  <Header value="UserID (Submit):"/>
+  <!-- submitter user-ID control ... -->
+
+  <Header value="Checked By (Team):"/>
+  <!-- team Choices ... -->
+
+  <Header value="UserID (Check):"/>
+  <!-- checker user-ID control ... -->
+</View>
\ No newline at end of file diff --git a/src/label-studio/label-studio-replica/docker-compose.yml b/src/label-studio/label-studio-replica/docker-compose.yml new file mode 100644 index 00000000..80bc06cc --- /dev/null +++ b/src/label-studio/label-studio-replica/docker-compose.yml @@ -0,0 +1,78 @@ +version: "3.9" +services: + nginx: + build: . + image: heartexlabs/label-studio:latest + restart: unless-stopped + ports: + - "8082:8085" + - "8083:8086" + depends_on: + - app + environment: + - LABEL_STUDIO_HOST=${LABEL_STUDIO_HOST:-} + # Optional: Specify SSL termination certificate & key + # Just drop your cert.pem and cert.key into folder 'deploy/nginx/certs' + # - NGINX_SSL_CERT=/certs/cert.pem + # - NGINX_SSL_CERT_KEY=/certs/cert.key + volumes: + - ./mydata:/label-studio/data:rw + - ./deploy/nginx/certs:/certs:ro + # Optional: Override nginx default conf + # - ./deploy/my.conf:/etc/nginx/nginx.conf + command: nginx + networks: + - label-studio-dev + + app: + stdin_open: true + tty: true + build: . + image: heartexlabs/label-studio:latest + restart: unless-stopped + expose: + - "8000" + depends_on: + - db-dev + environment: + - DJANGO_DB=default + - POSTGRE_NAME=postgres + - POSTGRE_USER=postgres + - POSTGRE_PASSWORD= + - POSTGRE_PORT=5432 + - POSTGRE_HOST=db-dev + - LABEL_STUDIO_HOST=${LABEL_STUDIO_HOST:-} + - JSON_LOG=1 + # - LOG_LEVEL=DEBUG + volumes: + - ./mydata:/label-studio/data:rw + command: label-studio-uwsgi + networks: + - label-studio-dev + + db-dev: + image: postgres:11.5 + hostname: db-dev + restart: unless-stopped + # Optional: Enable TLS on PostgreSQL + # Just drop your server.crt and server.key into folder 'deploy/pgsql/certs' + # NOTE: Both files must have permissions u=rw (0600) or less + # command: > + # -c ssl=on + # -c ssl_cert_file=/var/lib/postgresql/certs/server.crt + # -c ssl_key_file=/var/lib/postgresql/certs/server.key + ports: + - "5435:5432" + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + volumes: + - ${POSTGRES_DATA_DIR:-./postgres-data}:/var/lib/postgresql/data + - ${POSTGRES_BACKUPS_DIR:-./postgres-backups}:/var/lib/postgresql/backups + - ./deploy/pgsql/certs:/var/lib/postgresql/certs:ro + networks: + - label-studio-dev + +networks: + label-studio-dev: + name: label-studio-dev + driver: bridge diff --git a/src/label-studio/label-studio-replica/initialize_replica.py b/src/label-studio/label-studio-replica/initialize_replica.py new file mode 100644 index 00000000..71fc30bb --- /dev/null +++ b/src/label-studio/label-studio-replica/initialize_replica.py @@ -0,0 +1,99 @@ +import os +import time +from pathlib import Path + +import label_studio_sdk + +THIS_DIR = Path(__file__).parent + +# These are the API keys, read from environment variables. Hard-coding them +# here instead is acceptable only in a throwaway development environment. +dev_api_key = os.getenv("REPLICA_LABEL_STUDIO_API_KEY") +prd_api_key = os.getenv("LABEL_STUDIO_API_KEY") +dev_url = "http://localhost:8082" +prd_url = "http://localhost:8080" + +# We can initialize the SDK as follows. +# The client acts as the middleman between you, the programmer, and the +# Label Studio (LS) server. +dev_client = label_studio_sdk.Client(url=dev_url, api_key=dev_api_key) +prd_client = label_studio_sdk.Client(url=prd_url, api_key=prd_api_key) + +# This is the labelling interface configuration.
+# We can also save it as an XML file and import it later +dev_config = (THIS_DIR / "default_config.xml").read_text() + +# %% +print("Creating Development Project...") +# Creates the project; note that the label config is set here +dev_proj = dev_client.create_project( + title="FRDC Replica", + description="This is the replica project of FRDC. It's ok to break this.", + label_config=dev_config, + color="#FF0025", +) +# %% +print("Adding Import Source...") +# This links to our GCS as an import source +dev_storage = dev_proj.connect_google_import_storage( + bucket="frdc-ds", + regex_filter=".*.jpg", + google_application_credentials=( + THIS_DIR / "frmodel-943e4feae446.json" + ).read_text(), + presign=False, + title="Source", +) +time.sleep(5) +# %% +print("Syncing Storage...") +# Then, we sync it so that all the images appear as annotation targets +dev_proj.sync_storage( + storage_type=dev_storage["type"], + storage_id=dev_storage["id"], +) +time.sleep(5) +# %% +print("Retrieving Tasks...") +prd_proj = prd_client.get_project(id=1) +prd_tasks = prd_proj.get_tasks() +dev_tasks = dev_proj.get_tasks() +# %% +# This step copies the annotations from the production project over to the +# development project, creating each one as a "prediction" +print("Copying Annotations...") +for prd_task in prd_tasks: + # For each prod task, we find the corresponding (image) file name + prd_fn = prd_task["storage_filename"] + + # Then, we find the corresponding task in the development project + dev_tasks_matched = [ + t for t in dev_tasks if t["storage_filename"] == prd_fn + ] + + # Do some error handling + if len(dev_tasks_matched) == 0: + print(f"File not found in dev: {prd_fn}") + continue + if len(dev_tasks_matched) > 1: + print(f"Too many matches found in dev: {prd_fn}") + continue + + # Get the first match + dev_task = dev_tasks_matched[0] + + # Only keep annotations created by the "dev_evening" user + prd_ann = [ + ann + for ann in prd_task["annotations"] + if "dev_evening" in ann["created_username"] + ][0] + + # Create the prediction using the result from production + dev_proj.create_prediction( + task_id=dev_task["id"], + result=prd_ann["result"], + model_version="API Testing Prediction", + ) + +print("Done!") diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 97676dcb..3de3e6be 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -2,10 +2,8 @@ from pathlib import Path import lightning as pl -import numpy as np import pytest import torch -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.models.efficientnetb1 import ( EfficientNetB1MixMatchModule, @@ -33,23 +31,10 @@ def test_manual_segmentation_pipeline(model_fn, ds): val_ds=ds, batch_size=BATCH_SIZE, ) - - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, - ) - oe.fit(np.array(ds.targets).reshape(-1, 1)) - n_classes = len(oe.categories_[0]) - - ss = StandardScaler() - ss.fit(ds.ar.reshape(-1, ds.ar.shape[-1])) - m = model_fn( in_channels=ds.ar.shape[-1], - n_classes=n_classes, lr=1e-3, - x_scaler=ss, - y_encoder=oe, + out_targets=ds.targets, frozen=True, ) @@ -93,4 +78,4 @@ def test_manual_segmentation_pipeline(model_fn, ds): # E.g. achieved via hash comparison. # This is because BatchNorm usually keeps running statistics # and reloading the model will reset them. - # We don't necessarily need to + # We don't necessarily need to check for this.
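The training scripts below fit the StandardScaler on the training dataset and reuse it for validation via transform_scale=train_lab_ds.x_scaler. A minimal sketch of that behaviour, with made-up array shapes:

import numpy as np
from sklearn.preprocessing import StandardScaler

# Mirrors FRDCDataset with transform_scale=True: fit one scaler over
# every pixel of every (H x W x C) segment, per channel.
train_segments = [np.random.rand(4, 4, 5), np.random.rand(6, 3, 5)]
scaler = StandardScaler()
scaler.fit(
    np.concatenate([s.reshape(-1, s.shape[-1]) for s in train_segments])
)

# Mirrors transform_scale=<fitted scaler> on the validation dataset:
# reuse the training statistics instead of re-fitting on validation data.
val_segment = np.random.rand(5, 5, 5)
val_scaled = scaler.transform(
    val_segment.reshape(-1, val_segment.shape[-1])
).reshape(val_segment.shape)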
diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b1df238a..ac6b66a4 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -23,15 +23,14 @@ from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( - val_preprocess, - FRDCDatasetStaticEval, - n_weak_strong_aug, - get_y_encoder, - get_x_scaler, - weak_aug, + const_weak_aug, + n_rand_weak_strong_aug, + rand_weak_aug, ) +# %% + # Uncomment this to run the W&B monitoring locally # import os # @@ -56,11 +55,14 @@ def main( # Prepare the dataset im_size = 255 - train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) + train_lab_ds = ds.chestnut_20201218(transform=rand_weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor) + transform=n_rand_weak_strong_aug(im_size, unlabelled_factor), + ) + val_ds = ds.chestnut_20210510_43m( + transform=const_weak_aug(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -95,15 +97,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -116,15 +113,13 @@ def main( ) y_true, y_pred = predict( - ds=FRDCDatasetStaticEval( - "chestnut_nature_park", - "20210510", - "90deg43m85pct255deg", - transform=val_preprocess(im_size), + ds=ds.chestnut_20210510_43m.const_rotated( + transform=const_weak_aug(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 1bf839f5..edae259a 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -16,17 +16,15 @@ ) from lightning.pytorch.loggers import WandbLogger +from frdc.load.dataset import FRDCConstRotatedDataset from frdc.load.preset import FRDCDatasetPreset as ds from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( - val_preprocess, - FRDCDatasetStaticEval, - n_strong_aug, - strong_aug, - get_y_encoder, - get_x_scaler, + const_weak_aug, + n_rand_strong_aug, + rand_strong_aug, ) @@ -46,11 +44,16 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size)) + train_lab_ds = ds.chestnut_20201218( + transform=rand_strong_aug(im_size), + ) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_strong_aug(im_size, 2) + transform=n_rand_strong_aug(im_size, 2) + ) + val_ds = ds.chestnut_20210510_43m( + transform=const_weak_aug(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = 
ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -85,15 +88,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -106,15 +104,13 @@ def main( ) y_true, y_pred = predict( - ds=FRDCDatasetStaticEval( - "chestnut_nature_park", - "20210510", - "90deg43m85pct255deg", - transform=val_preprocess(im_size), + ds=ds.chestnut_20210510_43m.const_rotated( + transform=const_weak_aug(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index ac5b54ac..0cbb4ef5 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -2,10 +2,7 @@ from pathlib import Path -import numpy as np import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler -from torch import rot90 from torchvision.transforms import RandomVerticalFlip from torchvision.transforms.v2 import ( Compose, @@ -19,86 +16,40 @@ Resize, ) from torchvision.transforms.v2 import RandomHorizontalFlip -from torchvision.transforms.v2.functional import hflip - -from frdc.load.dataset import FRDCDataset THIS_DIR = Path(__file__).parent BANDS = ["NB", "NG", "NR", "RE", "NIR"] -class FRDCDatasetStaticEval(FRDCDataset): - def __len__(self): - """Assume that the dataset is 8x larger than it actually is. - - There are 8 possible orientations for each image. - 1. As-is - 2, 3, 4. Rotated 90, 180, 270 degrees - 5. Horizontally flipped - 6, 7, 8. 
Horizontally flipped and rotated 90, 180, 270 degrees - """ - return super().__len__() * 8 - - def __getitem__(self, idx): - """Alter the getitem method to implement the logic above.""" - x, y = super().__getitem__(int(idx // 8)) - assert x.ndim == 3, "x must be a 3D tensor" - x_ = None - if idx % 8 == 0: - x_ = x - elif idx % 8 == 1: - x_ = rot90(x, 1, (1, 2)) - elif idx % 8 == 2: - x_ = rot90(x, 2, (1, 2)) - elif idx % 8 == 3: - x_ = rot90(x, 3, (1, 2)) - elif idx % 8 == 4: - x_ = hflip(x) - elif idx % 8 == 5: - x_ = hflip(rot90(x, 1, (1, 2))) - elif idx % 8 == 6: - x_ = hflip(rot90(x, 2, (1, 2))) - elif idx % 8 == 7: - x_ = hflip(rot90(x, 3, (1, 2))) - - return x_, y - - -def val_preprocess(size: int): - return lambda x: Compose( - [ - ToImage(), - ToDtype(torch.float32, scale=True), - Resize(size, antialias=True), - CenterCrop(size), - ] - )(x) +def n_times(f, n: int): + return lambda x: [f(x) for _ in range(n)] -def n_weak_aug(size, n_aug: int = 2): - return lambda x: ( - [weak_aug(size)(x) for _ in range(n_aug)] if n_aug > 0 else None - ) +def n_rand_weak_aug(size, n_aug: int = 2): + return n_times(rand_weak_aug(size), n_aug) -def n_strong_aug(size, n_aug: int = 2): - return lambda x: ( - [strong_aug(size)(x) for _ in range(n_aug)] if n_aug > 0 else None - ) +def n_rand_strong_aug(size, n_aug: int = 2): + return n_times(rand_strong_aug(size), n_aug) -def n_weak_strong_aug(size, n_aug: int = 2): +def n_rand_weak_strong_aug(size, n_aug: int = 2): def f(x): - x_weak = n_weak_aug(size, n_aug)(x) - x_strong = n_strong_aug(size, n_aug)(x) - return list(zip(*[x_weak, x_strong])) if n_aug > 0 else None + # x_weak = [weak_0, weak_1, ..., weak_n] + x_weak = n_rand_weak_aug(size, n_aug)(x) + # x_strong = [strong_0, strong_1, ..., strong_n] + x_strong = n_rand_strong_aug(size, n_aug)(x) + # x_paired = [(weak_0, strong_0), (weak_1, strong_1), + # ..., (weak_n, strong_n)] + x_paired = list(zip(*[x_weak, x_strong])) + return x_paired return f -def weak_aug(size: int): - return lambda x: Compose( +def rand_weak_aug(size: int): + return Compose( [ ToImage(), ToDtype(torch.float32, scale=True), @@ -108,35 +59,29 @@ def weak_aug(size: int): RandomVerticalFlip(), RandomApply([RandomRotation((90, 90))], p=0.5), ] - )(x) + ) -def strong_aug(size: int): - return lambda x: Compose( +def const_weak_aug(size: int): + return Compose( [ ToImage(), ToDtype(torch.float32, scale=True), Resize(size, antialias=True), - RandomCrop(size, pad_if_needed=False), # Strong - RandomHorizontalFlip(), - RandomVerticalFlip(), - RandomApply([RandomRotation((90, 90))], p=0.5), + CenterCrop(size), ] - )(x) - - -def get_y_encoder(targets): - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, ) - oe.fit(np.array(targets).reshape(-1, 1)) - return oe -def get_x_scaler(segments): - ss = StandardScaler() - ss.fit( - np.concatenate([segm.reshape(-1, segm.shape[-1]) for segm in segments]) +def rand_strong_aug(size: int): + return Compose( + [ + ToImage(), + ToDtype(torch.float32, scale=True), + Resize(size, antialias=True), + RandomCrop(size, pad_if_needed=False), + RandomHorizontalFlip(), + RandomVerticalFlip(), + RandomApply([RandomRotation((90, 90))], p=0.5), + ] ) - return ss
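For reference, the fixed test-time orientations that moved from the removed FRDCDatasetStaticEval into FRDCConstRotatedDataset can be sanity-checked in isolation. The sketch below enumerates the same idx % 8 mapping on a dummy C x H x W tensor:

import torch
from torch import rot90
from torchvision.transforms.v2.functional import hflip

# Dummy 2-channel 3x3 image; dims (1, 2) are H and W, as in
# FRDCConstRotatedDataset.__getitem__.
x = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)

# The idx % 8 -> orientation mapping: identity, 3 rotations, then the
# horizontally-flipped versions of those four.
orientations = [
    x,
    rot90(x, 1, (1, 2)),
    rot90(x, 2, (1, 2)),
    rot90(x, 3, (1, 2)),
    hflip(x),
    hflip(rot90(x, 1, (1, 2))),
    hflip(rot90(x, 2, (1, 2))),
    hflip(rot90(x, 3, (1, 2))),
]

# All 8 views of a generic image are distinct, so pairing them with
# idx // 8 yields a dataset 8x the base length.
assert len({o.contiguous().numpy().tobytes() for o in orientations}) == 8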