From dc3a549ae5b2ce2f33c4b2c922ddb7a594bae7a8 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 16 Nov 2023 19:52:15 +0100 Subject: [PATCH] feat: extend hyperparameters of data modules --- rul_datasets/adaption.py | 12 ++++----- rul_datasets/baseline.py | 5 ++-- rul_datasets/core.py | 15 ++++++----- rul_datasets/reader/abstract.py | 12 ++++++--- rul_datasets/reader/cmapss.py | 8 ++++-- rul_datasets/reader/dummy.py | 9 +++++-- rul_datasets/reader/femto.py | 8 ++++-- rul_datasets/reader/xjtu_sy.py | 8 ++++-- rul_datasets/ssl.py | 9 +------ tests/reader/test_abstract.py | 46 ++++++++++++++++++++------------- tests/test_adaption.py | 25 +++++------------- tests/test_baseline.py | 2 +- tests/test_core.py | 16 +++++++++--- 13 files changed, 98 insertions(+), 77 deletions(-) diff --git a/rul_datasets/adaption.py b/rul_datasets/adaption.py index 737d8d8..9425ddf 100644 --- a/rul_datasets/adaption.py +++ b/rul_datasets/adaption.py @@ -57,6 +57,7 @@ def __init__( source: The data module of the labeled source domain. target: The data module of the unlabeled target domain. paired_val: Whether to include paired data in validation. + inductive: Whether to use the target test set for training. """ super().__init__() @@ -73,13 +74,10 @@ def __init__( self.save_hyperparameters( { - "fd_source": self.source.reader.fd, - "fd_target": self.target.reader.fd, - "batch_size": self.batch_size, - "window_size": self.source.reader.window_size, - "max_rul": self.source.reader.max_rul, - "percent_broken": self.target.reader.percent_broken, - "percent_fail_runs": self.target.reader.percent_fail_runs, + "source": self.source.hparams, + "target": self.target.hparams, + "paired_val": self.paired_val, + "inductive": self.inductive, } ) diff --git a/rul_datasets/baseline.py b/rul_datasets/baseline.py index c394d69..1f73956 100644 --- a/rul_datasets/baseline.py +++ b/rul_datasets/baseline.py @@ -50,15 +50,14 @@ def __init__(self, data_module: RulDataModule) -> None: super().__init__() self.data_module = data_module - hparams = self.data_module.hparams - self.save_hyperparameters(hparams) + self.save_hyperparameters(self.data_module.hparams) self.subsets = {} for fd in self.data_module.fds: self.subsets[fd] = self._get_fd(fd) def _get_fd(self, fd): - if fd == self.hparams["fd"]: + if fd == self.data_module.reader.fd: dm = self.data_module else: loader = deepcopy(self.data_module.reader) diff --git a/rul_datasets/core.py b/rul_datasets/core.py index 12aad64..e5a2300 100644 --- a/rul_datasets/core.py +++ b/rul_datasets/core.py @@ -1,7 +1,6 @@ """Basic data modules for experiments involving only a single subset of any RUL dataset. """ -from copy import deepcopy from typing import Dict, List, Optional, Tuple, Any, Callable import numpy as np @@ -105,12 +104,14 @@ def __init__( "to set a window size for re-windowing." ) - hparams = deepcopy(self.reader.hparams) - hparams["batch_size"] = self.batch_size - hparams["feature_extractor"] = ( - str(self.feature_extractor) if self.feature_extractor else None - ) - hparams["window_size"] = self.window_size or hparams["window_size"] + hparams = { + "reader": self.reader.hparams, + "batch_size": self.batch_size, + "feature_extractor": ( + str(self.feature_extractor) if self.feature_extractor else None + ), + "window_size": self.window_size, + } self.save_hyperparameters(hparams) @property diff --git a/rul_datasets/reader/abstract.py b/rul_datasets/reader/abstract.py index ba7031b..71ca0cc 100644 --- a/rul_datasets/reader/abstract.py +++ b/rul_datasets/reader/abstract.py @@ -92,10 +92,10 @@ def __init__( @property def hparams(self) -> Dict[str, Any]: - """A dictionary containing all input arguments of the constructor. This - information is used by the data modules to log their hyperparameters in - PyTorch Lightning.""" + """All information logged by the data modules as hyperparameters in PyTorch + Lightning.""" return { + "dataset": self.dataset_name, "fd": self.fd, "window_size": self.window_size, "max_rul": self.max_rul, @@ -105,6 +105,12 @@ def hparams(self) -> Dict[str, Any]: "truncate_degraded_only": self.truncate_degraded_only, } + @property + @abc.abstractmethod + def dataset_name(self) -> str: + """Name of the dataset.""" + raise NotImplementedError + @property @abc.abstractmethod def fds(self) -> List[int]: diff --git a/rul_datasets/reader/cmapss.py b/rul_datasets/reader/cmapss.py index 82328cf..3b8e6ef 100644 --- a/rul_datasets/reader/cmapss.py +++ b/rul_datasets/reader/cmapss.py @@ -23,9 +23,9 @@ class CmapssReader(AbstractReader): """ This reader represents the NASA CMAPSS Turbofan Degradation dataset. Each of its - four sub-datasets contain a training and a test split. Upon first usage, + four sub-datasets contains a training and a test split. Upon first usage, the training split will be further divided into a development and a validation - split. 20% of the original training split are reserved for validation. + split. 20% of the original training split is reserved for validation. The features are provided as sliding windows over each time series in the dataset. The label of a window is the label of its last time step. The RUL labels @@ -128,6 +128,10 @@ def __init__( self.feature_select = feature_select self.operation_condition_aware_scaling = operation_condition_aware_scaling + @property + def dataset_name(self) -> str: + return "cmapss" + @property def fds(self) -> List[int]: """Indices of available sub-datasets.""" diff --git a/rul_datasets/reader/dummy.py b/rul_datasets/reader/dummy.py index a205e7d..b4eb422 100644 --- a/rul_datasets/reader/dummy.py +++ b/rul_datasets/reader/dummy.py @@ -46,6 +46,7 @@ class DummyReader(AbstractReader): """ _FDS = [1, 2] + _DEFAULT_WINDOW_SIZE = 10 _NOISE_FACTOR = {1: 0.01, 2: 0.02} _OFFSET = {1: 0.5, 2: 0.75} @@ -62,12 +63,12 @@ def __init__( truncate_degraded_only: bool = False, ): """ - Create a new dummy reader for one of the two sub-datasets. The maximun RUL + Create a new dummy reader for one of the two sub-datasets. The maximum RUL value is set to 50 by default. Please be aware that changing this value will lead to different features, too, as they are calculated based on the RUL values. - For more information about using readers refer to the [reader] + For more information about using readers, refer to the [reader] [rul_datasets.reader] module page. Args: @@ -94,6 +95,10 @@ def __init__( scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1)) self.scaler = scaling.fit_scaler(features, scaler) + @property + def dataset_name(self) -> str: + return "xjtu-sy" + @property def fds(self) -> List[int]: """Indices of available sub-datasets.""" diff --git a/rul_datasets/reader/femto.py b/rul_datasets/reader/femto.py index 5152d7f..b6ec809 100644 --- a/rul_datasets/reader/femto.py +++ b/rul_datasets/reader/femto.py @@ -25,9 +25,9 @@ class FemtoReader(AbstractReader): """ This reader represents the FEMTO (PRONOSTIA) Bearing dataset. Each of its three - sub-datasets contain a training and a test split. By default, the reader + sub-datasets contains a training and a test split. By default, the reader constructs a validation split for sub-datasets 1 and 2 each by taking the first - run of the test split. For sub-dataset 3 the second training run is used for + run of the test split. For sub-dataset 3, the second training run is used for validation because only one test run is available. The remaining training data is denoted as the development split. This run to split assignment can be overridden by setting `run_split_dist`. @@ -130,6 +130,10 @@ def __init__( self._preparator = FemtoPreparator(self.fd, self._FEMTO_ROOT, run_split_dist) + @property + def dataset_name(self) -> str: + return "femto" + @property def fds(self) -> List[int]: """Indices of available sub-datasets.""" diff --git a/rul_datasets/reader/xjtu_sy.py b/rul_datasets/reader/xjtu_sy.py index f0dd976..f31b621 100644 --- a/rul_datasets/reader/xjtu_sy.py +++ b/rul_datasets/reader/xjtu_sy.py @@ -85,7 +85,7 @@ def __init__( constant. The `norm_rul` argument can then be used to scale the RUL of each run between zero and one. - For more information about using readers refer to the [reader] + For more information about using readers, refer to the [reader] [rul_datasets.reader] module page. Args: @@ -114,7 +114,7 @@ def __init__( if (first_time_to_predict is not None) and (max_rul is not None): raise ValueError( - "FemtoReader cannot use 'first_time_to_predict' " + "XjtuSyReader cannot use 'first_time_to_predict' " "and 'max_rul' in conjunction." ) @@ -123,6 +123,10 @@ def __init__( self._preparator = XjtuSyPreparator(self.fd, self._XJTU_SY_ROOT, run_split_dist) + @property + def dataset_name(self) -> str: + return "xjtu-sy" + @property def fds(self) -> List[int]: """Indices of available sub-datasets.""" diff --git a/rul_datasets/ssl.py b/rul_datasets/ssl.py index d94c47d..68192b5 100644 --- a/rul_datasets/ssl.py +++ b/rul_datasets/ssl.py @@ -49,14 +49,7 @@ def __init__(self, labeled: RulDataModule, unlabeled: RulDataModule) -> None: self._check_compatibility() self.save_hyperparameters( - { - "fd": self.labeled.reader.fd, - "batch_size": self.batch_size, - "window_size": self.labeled.reader.window_size, - "max_rul": self.labeled.reader.max_rul, - "percent_broken_unlabeled": self.unlabeled.reader.percent_broken, - "percent_fail_runs_labeled": self.labeled.reader.percent_fail_runs, - } + {"labeled": self.labeled.hparams, "unlabeled": self.unlabeled.hparams} ) def _check_compatibility(self) -> None: diff --git a/tests/reader/test_abstract.py b/tests/reader/test_abstract.py index 3d3d1ee..f121fb6 100644 --- a/tests/reader/test_abstract.py +++ b/tests/reader/test_abstract.py @@ -7,7 +7,7 @@ from rul_datasets import reader -class DummyReader(reader.AbstractReader): +class DummyAbstractReader(reader.AbstractReader): fd: int window_size: int max_rul: int @@ -17,6 +17,10 @@ class DummyReader(reader.AbstractReader): _NUM_TRAIN_RUNS = {1: 100} + @property + def dataset_name(self) -> str: + return "dummy_abstract" + @property def fds(self): return [1] @@ -36,17 +40,21 @@ def load_complete_split( class TestAbstractLoader: @mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], [])) def test_truncation_dev_split(self, mock_truncate_runs): - this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8) + this = DummyAbstractReader( + 1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8 + ) this.load_split("dev") mock_truncate_runs.assert_called_with([], [], 0.2, 0.8, False) @mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], [])) def test_truncation_val_split(self, mock_truncate_runs): - this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8) + this = DummyAbstractReader( + 1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8 + ) this.load_split("val") mock_truncate_runs.assert_not_called() - this = DummyReader( + this = DummyAbstractReader( 1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8, truncate_val=True ) this.load_split("val") @@ -54,20 +62,22 @@ def test_truncation_val_split(self, mock_truncate_runs): @mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], [])) def test_truncation_test_split(self, mock_truncate_runs): - this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8) + this = DummyAbstractReader( + 1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8 + ) this.load_split("val") mock_truncate_runs.assert_not_called() def test_check_compatibility(self): - this = DummyReader(1, 30, 125) - this.check_compatibility(DummyReader(1, 30, 125)) + this = DummyAbstractReader(1, 30, 125) + this.check_compatibility(DummyAbstractReader(1, 30, 125)) with pytest.raises(ValueError): - this.check_compatibility(DummyReader(1, 20, 125)) + this.check_compatibility(DummyAbstractReader(1, 20, 125)) with pytest.raises(ValueError): - this.check_compatibility(DummyReader(1, 30, 120)) + this.check_compatibility(DummyAbstractReader(1, 30, 120)) def test_get_compatible_same(self): - this = DummyReader(1, 30, 125) + this = DummyAbstractReader(1, 30, 125) other = this.get_compatible() this.check_compatibility(other) assert other is not this @@ -79,7 +89,7 @@ def test_get_compatible_same(self): assert this.truncate_val == other.truncate_val def test_get_compatible_different(self): - this = DummyReader(1, 30, 125) + this = DummyAbstractReader(1, 30, 125) other = this.get_compatible(2, 0.2, 0.8, False) this.check_compatibility(other) assert other is not this @@ -92,21 +102,21 @@ def test_get_compatible_different(self): assert not other.truncate_val def test_get_complement_percentage(self): - this = DummyReader(1, 30, 125, percent_fail_runs=0.8) + this = DummyAbstractReader(1, 30, 125, percent_fail_runs=0.8) other = this.get_complement(0.8, False) assert other.percent_fail_runs == list(range(80, 100)) assert 0.8 == other.percent_broken assert not other.truncate_val def test_get_complement_idx(self): - this = DummyReader(1, 30, 125, percent_fail_runs=list(range(80))) + this = DummyAbstractReader(1, 30, 125, percent_fail_runs=list(range(80))) other = this.get_complement(0.8, False) assert other.percent_fail_runs == list(range(80, 100)) assert 0.8 == other.percent_broken assert not other.truncate_val def test_get_complement_empty(self): - this = DummyReader(1, 30, 125) # Uses all runs + this = DummyAbstractReader(1, 30, 125) # Uses all runs other = this.get_complement(0.8, False) assert not other.percent_fail_runs # Complement is empty assert 0.8 == other.percent_broken @@ -125,8 +135,8 @@ def test_get_complement_empty(self): ], ) def test_is_mutually_exclusive(self, runs_this, runs_other, success): - this = DummyReader(1, percent_fail_runs=runs_this) - other = DummyReader(1, percent_fail_runs=runs_other) + this = DummyAbstractReader(1, percent_fail_runs=runs_this) + other = DummyAbstractReader(1, percent_fail_runs=runs_other) assert this.is_mutually_exclusive(other) == success assert other.is_mutually_exclusive(this) == success @@ -136,7 +146,7 @@ def test_is_mutually_exclusive(self, runs_this, runs_other, success): [("override", 30, 30), ("min", 15, 15), ("none", 30, 15)], ) def test_consolidate_window_size(self, mode, expected_this, expected_other): - this = DummyReader(1, window_size=30) + this = DummyAbstractReader(1, window_size=30) other = this.get_compatible(2, consolidate_window_size=mode) assert this.window_size == expected_this @@ -157,7 +167,7 @@ def test_consolidate_window_size(self, mode, expected_this, expected_other): ) @mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], [])) def test_alias(self, mock_truncate_runs, split, alias, truncate_val, exp_truncated): - this = DummyReader(1, truncate_val=truncate_val) + this = DummyAbstractReader(1, truncate_val=truncate_val) this.load_complete_split = mock.Mock(wraps=this.load_complete_split) this.load_split(split, alias) diff --git a/tests/test_adaption.py b/tests/test_adaption.py index 0eee5db..3dac59d 100644 --- a/tests/test_adaption.py +++ b/tests/test_adaption.py @@ -8,7 +8,7 @@ from torch.utils.data import RandomSampler, TensorDataset import rul_datasets -from rul_datasets import adaption, core +from rul_datasets import adaption, core, CmapssReader from rul_datasets.reader import DummyReader from tests.templates import PretrainingDataModuleTemplate @@ -16,16 +16,12 @@ class TestDomainAdaptionDataModule(unittest.TestCase): def setUp(self): source_mock_runs = [np.random.randn(16, 14, 1)] * 3, [np.random.rand(16)] * 3 - self.source_loader = mock.MagicMock(name="CMAPSSLoader") + self.source_loader = mock.MagicMock(CmapssReader) self.source_loader.fd = 3 self.source_loader.percent_fail_runs = None self.source_loader.percent_broken = None self.source_loader.window_size = 1 self.source_loader.max_rul = 125 - self.source_loader.hparams = { - "fd": self.source_loader.fd, - "window_size": self.source_loader.window_size, - } self.source_loader.load_split.return_value = source_mock_runs self.source_data = mock.MagicMock(rul_datasets.RulDataModule) self.source_data.reader = self.source_loader @@ -33,16 +29,12 @@ def setUp(self): self.source_data.to_dataset.return_value = TensorDataset(torch.zeros(1)) target_mock_runs = [np.random.randn(16, 14, 1)] * 2, [np.random.rand(16)] * 2 - self.target_loader = mock.MagicMock(name="CMAPSSLoader") + self.target_loader = mock.MagicMock(CmapssReader) self.target_loader.fd = 1 self.target_loader.percent_fail_runs = 0.8 self.target_loader.percent_broken = 0.8 self.target_loader.window_size = 1 self.target_loader.max_rul = 125 - self.target_loader.hparams = { - "fd": self.target_loader.fd, - "window_size": self.target_loader.window_size, - } self.target_loader.load_split.return_value = target_mock_runs self.target_data = mock.MagicMock(rul_datasets.RulDataModule) self.target_data.reader = self.target_loader @@ -143,13 +135,10 @@ def test_truncated_loader(self): def test_hparams(self): expected_hparams = { - "fd_source": 3, - "fd_target": 1, - "batch_size": 16, - "window_size": 1, - "max_rul": 125, - "percent_broken": 0.8, - "percent_fail_runs": 0.8, + "source": self.dataset.source.hparams, + "target": self.dataset.target.hparams, + "paired_val": False, + "inductive": False, } self.assertDictEqual(expected_hparams, self.dataset.hparams) diff --git a/tests/test_baseline.py b/tests/test_baseline.py index bb4814c..f53c5d8 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -32,7 +32,7 @@ def test_test_sets_created_correctly(self): for fd in self.mock_loader.fds: self.assertIn(fd, self.dataset.subsets) self.assertEqual(fd, self.dataset.subsets[fd].reader.fd) - if fd == self.dataset.hparams["fd"]: + if fd == self.dataset.data_module.reader.fd: self.assertIs(self.dataset.data_module, self.dataset.subsets[fd]) else: self.assertIsNone(self.dataset.subsets[fd].reader.percent_fail_runs) diff --git a/tests/test_core.py b/tests/test_core.py index 1459035..4e8c19a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -33,9 +33,9 @@ def test_created_correctly(self, mock_loader): assert mock_loader is dataset.reader assert 16 == dataset.batch_size assert dataset.hparams == { - "test": 0, + "reader": {"test": 0, "window_size": 30}, "batch_size": 16, - "window_size": mock_loader.hparams["window_size"], + "window_size": None, "feature_extractor": None, } @@ -49,9 +49,9 @@ def test_created_correctly_with_feature_extractor(self, mock_loader, window_size assert mock_loader is dataset.reader assert 16 == dataset.batch_size assert dataset.hparams == { - "test": 0, + "reader": {"test": 0, "window_size": 30}, "batch_size": 16, - "window_size": window_size or mock_loader.hparams["window_size"], + "window_size": window_size, "feature_extractor": str(fe), } @@ -262,6 +262,10 @@ def __init__(self, length): ), } + @property + def dataset_name(self) -> str: + return "dummy_rul" + @property def fds(self): return [1] @@ -308,6 +312,10 @@ class DummyRulShortRuns(reader.AbstractReader): ), } + @property + def dataset_name(self) -> str: + return "dummy_rul_short_runs" + @property def fds(self): return [1]