def _download_ncmapss(data_root: str) -> None:
    """
    Download the N-C-MAPSS archive from Google Drive and unpack it into `data_root`.

    The archive is fetched into a temporary directory that is discarded
    afterward. All files of the archive are extracted directly into `data_root`
    without their folder structure, i.e. only the base name of each member is
    kept. Members with the same base name overwrite each other.

    The target directory is created only after the download succeeded and is
    removed again if the extraction fails. Callers decide whether to download by
    checking if the directory exists, so a half-populated directory would
    otherwise prevent any retry.

    Args:
        data_root: Directory to extract the dataset files into. It must not
            exist yet.

    Raises:
        FileExistsError: If `data_root` already exists.
    """
    # Local import so this function does not depend on a module-level import.
    import shutil

    with tempfile.TemporaryDirectory() as tmp_path:
        print("Download N-C-MAPSS dataset from Google Drive")
        archive_path = os.path.join(tmp_path, "data.zip")
        # NOTE(review): presumably stores the Drive file under `archive_path`;
        # confirm against `utils.download_gdrive_file`.
        utils.download_gdrive_file(NCMAPSS_DRIVE_ID, archive_path)

        print("Extract N-C-MAPSS dataset")
        # Created without exist_ok on purpose: the directory is guaranteed to be
        # new, so it is safe to delete it again if the extraction fails.
        os.makedirs(data_root)
        try:
            with zipfile.ZipFile(archive_path, mode="r") as archive:
                for member in archive.infolist():
                    # Rewriting the member name makes `extract` place the file
                    # directly in `data_root` and keeps it from leaving it.
                    member.filename = os.path.basename(member.filename)
                    if member.filename:  # directory entries have no base name
                        archive.extract(member, data_root)
        except BaseException:
            # Also covers KeyboardInterrupt; the error is always re-raised.
            shutil.rmtree(data_root, ignore_errors=True)
            raise