Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengyanZhu committed Mar 20, 2024
1 parent 494052f commit 757d9b6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
15 changes: 9 additions & 6 deletions rul_datasets/reader/ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ def __init__(
truncate_degraded_only: bool = False,
resolution_seconds: int = 1,
padding_value: float = 0.0,
scaling_range: [float, float] = None,
) -> None:
"""
Create a new reader for the New C-MAPSS dataset. The maximum RUL value is set
to 65 by default. The default channels are the four operating conditions,
to 65 by default. The default channels are the four operating conditions,
the 14 physical, and 14 virtual sensors in this order.
The default window size is, by default, the longest flight cycle in the
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self.run_split_dist = run_split_dist or self._get_default_split(self.fd)
self.resolution_seconds = resolution_seconds
self.padding_value = padding_value
self.scaling_range = scaling_range

if self.resolution_seconds > 1 and window_size is None:
warnings.warn(
Expand All @@ -189,6 +191,7 @@ def hparams(self):
"run_split_dist": self.run_split_dist,
"feature_select": self.feature_select,
"padding_value": self.padding_value,
"scaling_range": self.scaling_range,
}
)

Expand All @@ -214,10 +217,10 @@ def prepare_data(self) -> None:
"""
if not os.path.exists(self._NCMAPSS_ROOT):
_download_ncmapss(self._NCMAPSS_ROOT)
if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler())
scaling.save_scaler(scaler, self._get_scaler_path())
#if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler())
scaling.save_scaler(scaler, self._get_scaler_path())

def _get_scaler_path(self):
file_name = f"scaler_{self.fd}_{self.run_split_dist['dev']}.pkl"
Expand Down Expand Up @@ -301,7 +304,7 @@ def _select_units(self, units, split):
return [units[i] for i in self.run_split_dist[split]]

def _window_by_cycle(
self, features: np.ndarray, targets: np.ndarray, auxiliary: np.ndarray
self, features: np.ndarray, targets: np.ndarray, auxiliary: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
cycle_end_idx = self._get_end_idx(auxiliary[:, 1])
split_features = np.split(features, cycle_end_idx[:-1])
Expand Down
Empty file added tests/reader/test_ncmapps.py
Empty file.
16 changes: 16 additions & 0 deletions tests/reader/test_ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ def test_prepare_data(should_run, mocker):
mock_save_scaler.assert_not_called()



@pytest.mark.needs_data
@pytest.mark.parametrize("scaling_range", [(-1.0, 1.0), (0.0, 2.0)])
def test_scaling_range(scaling_range):
reader = NCmapssReader(fd=1, scaling_range=scaling_range)
reader.prepare_data()
features, _ = reader.load_split("dev")

reader = NCmapssReader(fd=1, scaling_range=(0, 1))
reader.prepare_data()
features_default, _ = reader.load_split("dev")

assert not np.array_equal(features[0][:, :, 1], features_default[0][:, :, 1])



@pytest.mark.needs_data
@pytest.mark.parametrize("fd", list(range(1, 8)))
@pytest.mark.parametrize("split", ["dev", "val", "test"])
Expand Down

0 comments on commit 757d9b6

Please sign in to comment.