diff --git a/poetry.lock b/poetry.lock index 01cbe0a..cd608a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -230,17 +230,6 @@ files = [ {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, ] -[[package]] -name = "cfgv" -version = "3.4.0" -description = "Validate configuration and produce human readable error messages." -optional = false -python-versions = ">=3.8" -files = [ - {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, - {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, -] - [[package]] name = "charset-normalizer" version = "3.3.2" @@ -443,17 +432,6 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] -[[package]] -name = "distlib" -version = "0.3.8" -description = "Distribution utilities" -optional = false -python-versions = "*" -files = [ - {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, - {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, -] - [[package]] name = "docker-pycreds" version = "0.4.0" @@ -922,20 +900,6 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] -[[package]] -name = "identify" -version = "2.5.33" -description = "File identification library for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "identify-2.5.33-py2.py3-none-any.whl", hash = "sha256:d40ce5fcd762817627670da8a7d8d8e65f24342d14539c59488dc603bf662e34"}, - {file = "identify-2.5.33.tar.gz", hash = "sha256:161558f9fe4559e1557e1bff323e8631f6a0e4837f7497767c1782832f16b62d"}, -] - -[package.extras] -license = ["ukkonen"] - [[package]] name = "idna" version = "3.6" @@ -1674,20 +1638,6 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] -[[package]] -name = "nodeenv" -version = "1.8.0" -description = "Node.js virtual environment builder" -optional = false -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" -files = [ - {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, - {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, -] - -[package.dependencies] -setuptools = "*" - [[package]] name = "numpy" version = "1.26.2" @@ -2094,24 +2044,6 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] -[[package]] -name = "pre-commit" -version = "3.6.0" -description = "A framework for managing and maintaining multi-language pre-commit hooks." 
-optional = false -python-versions = ">=3.9" -files = [ - {file = "pre_commit-3.6.0-py2.py3-none-any.whl", hash = "sha256:c255039ef399049a5544b6ce13d135caba8f2c28c3b4033277a788f434308376"}, - {file = "pre_commit-3.6.0.tar.gz", hash = "sha256:d30bad9abf165f7785c15a21a1f46da7d0677cb00ee7ff4c579fd38922efe15d"}, -] - -[package.dependencies] -cfgv = ">=2.0.0" -identify = ">=1.0.0" -nodeenv = ">=0.11.1" -pyyaml = ">=5.1" -virtualenv = ">=20.10.0" - [[package]] name = "protobuf" version = "4.25.1" @@ -3161,26 +3093,6 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] -[[package]] -name = "virtualenv" -version = "20.25.0" -description = "Virtual Python Environment builder" -optional = false -python-versions = ">=3.7" -files = [ - {file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"}, - {file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"}, -] - -[package.dependencies] -distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<5" - -[package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] - [[package]] name = "wandb" version = "0.16.1" @@ -3460,4 +3372,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "16b11478ea4926d0df25a68595bb84c19b171107a812edae90838c926ed060ac" +content-hash = "ac048851b4e31289fa53600547920b790209eec3456fd642e0b0666ead39bd21" diff --git a/pyproject.toml b/pyproject.toml index 3864acf..65c7094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ lightning = "^2.3.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" -pre-commit = "^3.5.0" black = "^23.10.0" wandb = "^0.16.0" plotly = "^5.22.0" diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index ca18f44..30041d4 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -8,11 +8,14 @@ import numpy as np import pandas as pd +import torch from PIL import Image from sklearn.preprocessing import StandardScaler from torch import rot90 from torch.utils.data import Dataset, ConcatDataset +from torchvision.transforms.v2 import Compose from torchvision.transforms.v2.functional import hflip +from torchvision.tv_tensors import Image as ImageTensor from frdc.conf import BAND_CONFIG, LABEL_STUDIO_CLIENT from frdc.load import gcs @@ -21,11 +24,51 @@ extract_segments_from_bounds, extract_segments_from_polybounds, ) -from frdc.utils import Rect +from frdc.utils.utils import Rect, flatten_nested, map_nested logger = logging.getLogger(__name__) +class ImageStandardScaler(StandardScaler): + def fit(self, X, y=None, sample_weight=None): + X = X.reshape(X.shape[0], -1) + return super().fit(X, y, sample_weight) + + def transform(self, X, copy=None): + shape = X.shape + X = X.reshape(shape[0], -1) + X = torch.nan_to_num(X, nan=0) + X = super().transform(X, copy).reshape(*shape) + X = torch.tensor(X) + return 
X.to(torch.float32)
+
+    def transform_one(self, X, copy=None):
+        shape = X.shape
+        X = X.reshape(1, -1)
+        return self.transform(X, copy).reshape(shape)
+
+    def inverse_transform(self, X, y=None, **fit_params):
+        shape = X.shape
+        X = X.reshape(shape[0], -1)
+        return (
+            super()
+            .inverse_transform(X, y, **fit_params)
+            .reshape(shape[0], *shape[1:])
+        )
+
+    def fit_nested(self, X):
+        # Adapted method of `fit` to handle nested lists/tuples
+        X = torch.stack(flatten_nested(X, type_list=(list, tuple)))
+        self.fit(X)
+        return self
+
+    def transform_nested(self, X):
+        # Adapted method of `transform` to handle nested lists/tuples
+        # This preserves the nested structure of the input by treating every
+        # atom as a single entity and transforming as-is.
+        return map_nested(X, self.transform_one, ImageTensor, (list, tuple))
+
+
 @dataclass
 class FRDCDataset(Dataset):
     def __init__(
@@ -33,8 +76,7 @@ def __init__(
         site: str,
         date: str,
         version: str | None,
-        transform: Callable[[np.ndarray], Any] = lambda x: x,
-        transform_scale: bool | StandardScaler = True,
+        transform: Compose = lambda x: x,
         target_transform: Callable[[str], str] = lambda x: x,
         use_legacy_bounds: bool = False,
         polycrop: bool = False,
@@ -58,9 +100,6 @@ def __init__(
             date: The date of the dataset, e.g. "20201218".
             version: The version of the dataset, e.g. "183deg".
             transform: The transform to apply to each segment.
-            transform_scale: Whether to scale the data. If True, it will fit
-                a StandardScaler to the data. If a StandardScaler is passed,
-                it will use that instead. If False, it will not scale the data.
             target_transform: The transform to apply to each label.
             use_legacy_bounds: Whether to use the legacy bounds.csv file.
                 This will automatically be set to True if LABEL_STUDIO_CLIENT
@@ -100,33 +139,6 @@ def __init__(
         self.transform = transform
         self.target_transform = target_transform
 
-        if transform_scale is True:
-            self.x_scaler = StandardScaler()
-            self.x_scaler.fit(
-                np.concatenate(
-                    [
-                        # Segments: [H x W x C] -> [H*W, C]
-                        # Reshaping is necessary for StandardScaler
-                        segm.reshape(-1, segm.shape[-1])
-                        for segm in self.ar_segments
-                    ]
-                )
-            )
-            self.transform = lambda x: transform(
-                self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
-                    x.shape
-                )
-            )
-        elif isinstance(transform_scale, StandardScaler):
-            self.x_scaler = transform_scale
-            self.transform = lambda x: transform(
-                self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
-                    x.shape
-                )
-            )
-        else:
-            self.x_scaler = None
-
     def __len__(self):
         return len(self.ar_segments)
 
@@ -159,7 +171,7 @@ def _get_ar_bands_as_dict(
             get all bands in BAND_CONFIG.
 
         Examples:
-            >>> get_ar_bands_as_dict(['WB', 'WG', 'WR']])
+            >>> self._get_ar_bands_as_dict(['WB', 'WG', 'WR'])
 
         Returns
 
@@ -208,7 +220,7 @@ def _get_ar_bands(
             get all bands in BAND_CONFIG.
 
Examples - >>> get_ar_bands(['WB', 'WG', 'WR']) + >>> self._get_ar_bands(['WB', 'WG', 'WR']) Returns diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index 2f39d4a..b349152 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,7 +6,6 @@ import numpy as np import torch -from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -52,8 +51,7 @@ class FRDCDatasetPartial: def __call__( self, - transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool | StandardScaler = True, + transform: Compose = lambda x: x, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -63,18 +61,14 @@ def __call__( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. If True, it will fit - a StandardScaler to the data. If a StandardScaler is passed, - it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. polycrop_value: The value to use for polycrop. """ return self.labelled( - transform, - transform_scale, - target_transform, + transform=transform, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -82,8 +76,7 @@ def __call__( def labelled( self, - transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool | StandardScaler = True, + transform: Compose = lambda x: x, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -93,9 +86,6 @@ def labelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. If True, it will fit - a StandardScaler to the data. If a StandardScaler is passed, - it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -106,7 +96,6 @@ def labelled( self.date, self.version, transform=transform, - transform_scale=transform_scale, target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, @@ -115,8 +104,7 @@ def labelled( def unlabelled( self, - transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool | StandardScaler = True, + transform: Compose = lambda x: x, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -131,9 +119,6 @@ def unlabelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. If True, it will fit - a StandardScaler to the data. If a StandardScaler is passed, - it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. 
@@ -144,7 +129,6 @@ def unlabelled( self.date, self.version, transform=transform, - transform_scale=transform_scale, target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, @@ -153,8 +137,7 @@ def unlabelled( def const_rotated( self, - transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool | StandardScaler = True, + transform: Compose = lambda x: x, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -169,9 +152,6 @@ def const_rotated( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. If True, it will fit - a StandardScaler to the data. If a StandardScaler is passed, - it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -182,7 +162,6 @@ def const_rotated( self.date, self.version, transform=transform, - transform_scale=transform_scale, target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py index 22582e9..f16886c 100644 --- a/src/frdc/train/frdc_module.py +++ b/src/frdc/train/frdc_module.py @@ -6,7 +6,7 @@ from sklearn.preprocessing import OrdinalEncoder from frdc.models.utils import save_unfrozen, load_checkpoint_lenient -from frdc.utils.utils import fn_recursive +from frdc.utils.utils import map_nested class FRDCModule(LightningModule): @@ -98,7 +98,7 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: # This function applies nan_to_num to all tensors in the list, # regardless of how deeply nested they are. 
- x_unl_trans = fn_recursive( + x_unl_trans = map_nested( x_unl, fn=lambda x: torch.nan_to_num(x[nan]), type_atom=torch.Tensor, diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index b42ce91..668e4d5 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -3,8 +3,6 @@ import wandb from sklearn.preprocessing import StandardScaler, OrdinalEncoder -from frdc.utils.utils import fn_recursive - def mix_up( x: torch.Tensor, diff --git a/src/frdc/utils/utils.py b/src/frdc/utils/utils.py index 8402eb1..539ed5f 100644 --- a/src/frdc/utils/utils.py +++ b/src/frdc/utils/utils.py @@ -3,7 +3,7 @@ Rect = namedtuple("Rect", ["x0", "y0", "x1", "y1"]) -def fn_recursive(x, fn, type_atom, type_list): +def map_nested(x, fn, type_atom, type_list): """Recursively applies a function to the data preserving the structure Args: @@ -15,6 +15,15 @@ def fn_recursive(x, fn, type_atom, type_list): if isinstance(x, type_atom): return fn(x) elif isinstance(x, type_list): - return [fn_recursive(item, fn, type_atom, type_list) for item in x] + return [map_nested(item, fn, type_atom, type_list) for item in x] else: return x + + +def flatten_nested(x, type_list): + """Flattens a nested list""" + if not isinstance(x, type_list): + return [x] + return [ + item for sub in x for item in flatten_nested(sub, type_list=type_list) + ] diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index ac6b66a..831d611 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -16,6 +16,7 @@ ) from lightning.pytorch.loggers import WandbLogger +from frdc.load.dataset import ImageStandardScaler from frdc.load.preset import FRDCDatasetPreset as ds from frdc.models.efficientnetb1 import ( EfficientNetB1FixMatchModule, @@ -56,13 +57,17 @@ def main( # Prepare the dataset im_size = 255 train_lab_ds = ds.chestnut_20201218(transform=rand_weak_aug(im_size)) + iss = ImageStandardScaler().fit_nested(train_lab_ds[:][0]) + train_lab_ds.transform.transforms.append(iss.transform_nested) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_rand_weak_strong_aug(im_size, unlabelled_factor), ) + + train_unl_ds.transform.transforms.append(iss.transform_nested) val_ds = ds.chestnut_20210510_43m( transform=const_weak_aug(im_size), - transform_scale=train_lab_ds.x_scaler, ) + val_ds.transform.transforms.append(iss.transform_nested) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -111,14 +116,11 @@ def main( f"# Chestnut Nature Park (Dec 2020 vs May 2021) FixMatch\n" f"- Results: [WandB Report]({wandb.run.get_url()})\n" ) - - y_true, y_pred = predict( - ds=ds.chestnut_20210510_43m.const_rotated( - transform=const_weak_aug(im_size), - transform_scale=train_lab_ds.x_scaler, - ), - model=m, + test_ds = ds.chestnut_20210510_43m.const_rotated( + transform=const_weak_aug(im_size), ) + test_ds.transform.transforms.append(iss.transform_nested) + y_true, y_pred = predict(ds=test_ds, model=m) fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index 0cbb4ef..5399039 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -35,17 +35,18 @@ def n_rand_strong_aug(size, n_aug: int = 2): def n_rand_weak_strong_aug(size, n_aug: int = 2): - def f(x): - # x_weak = [weak_0, weak_1, ..., weak_n] - x_weak = 
n_rand_weak_aug(size, n_aug)(x)
-        # x_strong = [strong_0, strong_1, ..., strong_n]
-        x_strong = n_rand_strong_aug(size, n_aug)(x)
-        # x_paired = [(weak_0, strong_0), (weak_1, strong_1),
-        #            ..., (weak_n, strong_n)]
-        x_paired = list(zip(*[x_weak, x_strong]))
-        return x_paired
-
-    return f
+    # Pair each weak augmentation with its strong counterpart:
+    # [(weak_0, strong_0), (weak_1, strong_1), ..., (weak_n, strong_n)]
+    return Compose(
+        [
+            lambda x: list(
+                zip(
+                    n_rand_weak_aug(size, n_aug)(x),
+                    n_rand_strong_aug(size, n_aug)(x),
+                )
+            )
+        ]
+    )
 
 
 def rand_weak_aug(size: int):
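
How the new pieces fit together, as a minimal sketch: FRDCDataset no longer
fits a StandardScaler internally, so callers fit an ImageStandardScaler once
on the labelled segments and append its transform_nested to each dataset's
Compose (as train_fixmatch.py now does). The img helper, tensor shapes, and
toy nesting below are illustrative assumptions, not repo code; the imports
mirror the modules this patch touches.

    import torch
    from torchvision.tv_tensors import Image as ImageTensor

    from frdc.load.dataset import ImageStandardScaler
    from frdc.utils.utils import flatten_nested, map_nested

    def img():
        # Stand-in for one augmented segment, shaped [C, H, W].
        return ImageTensor(torch.rand(3, 32, 32))

    # Mimics the nested (weak, strong) batches from n_rand_weak_strong_aug.
    nested = [img(), (img(), img())]

    # flatten_nested collapses the nesting so fit_nested can stack a single
    # [N, C, H, W] batch; map_nested applies fn atom-by-atom while keeping
    # the nesting (tuples come back as lists).
    assert len(flatten_nested(nested, type_list=(list, tuple))) == 3
    doubled = map_nested(
        nested, lambda t: t * 2, type_atom=ImageTensor, type_list=(list, tuple)
    )

    # Fit once on the flattened stack, then transform without flattening:
    # the result keeps the structure [img, [img, img]], each a float32 tensor.
    iss = ImageStandardScaler().fit_nested(nested)
    scaled = iss.transform_nested(nested)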