Skip to content

Commit

Permalink
feat: add auto-download for N-C-MAPSS
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 committed Jan 12, 2024
1 parent dcec90d commit f9a9aa6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions rul_datasets/reader/ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@
corrupted. The dataset authors were already contacted about this issue."""

import os
import tempfile
import warnings
import zipfile
from typing import Tuple, List, Optional, Union, Dict

import h5py # type: ignore[import]
import numpy as np
from sklearn.preprocessing import MinMaxScaler # type: ignore[import]

from rul_datasets import utils
from rul_datasets.reader.data_root import get_data_root
from rul_datasets.reader import AbstractReader, scaling


NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn"


class NCmapssReader(AbstractReader):
"""
This reader provides access to the New C-MAPSS Turbofan Degradation dataset. Each
Expand Down Expand Up @@ -206,6 +212,8 @@ def prepare_data(self) -> None:
data is then split into development and validation set. Afterward, a scaler
is fit on the development features if it was not already done previously.
"""
if not os.path.exists(self._NCMAPSS_ROOT):
_download_ncmapss(self._NCMAPSS_ROOT)
if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler())
Expand Down Expand Up @@ -352,3 +360,17 @@ def _calc_default_window_size(self):
max_window_size.append(max(*[len(f) for f in split_features]))

return max(*max_window_size)


def _download_ncmapss(data_root):
os.makedirs(data_root)
with tempfile.TemporaryDirectory() as tmp_path:
print("Download N-C-MAPSS dataset from Google Drive")
download_path = os.path.join(tmp_path, "data.zip")
utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path)
print("Extract N-C-MAPSS dataset")
with zipfile.ZipFile(download_path, mode="r") as f:
for zipinfo in f.infolist():
zipinfo.filename = os.path.basename(zipinfo.filename)
if zipinfo.filename:
f.extract(zipinfo, data_root)

0 comments on commit f9a9aa6

Please sign in to comment.