feat: add auto-download for ncmapss (#51)
* feat: add function to download from gdrive

* feat: add auto-download for N-C-MAPSS
tilman151 authored Jan 12, 2024
1 parent d47897f commit 90ec9a7
Showing 2 changed files with 46 additions and 3 deletions.
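With this change, the first call to prepare_data fetches N-C-MAPSS from Google Drive instead of failing on a missing data root. A minimal usage sketch; the fd constructor argument is an assumption for illustration:

    from rul_datasets.reader.ncmapss import NCmapssReader

    # First call: downloads and extracts the archive into the data root.
    # Later calls: the path check passes and the download is skipped.
    reader = NCmapssReader(fd=1)
    reader.prepare_data()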
22 changes: 22 additions & 0 deletions rul_datasets/reader/ncmapss.py
@@ -12,17 +12,23 @@
 corrupted. The dataset authors were already contacted about this issue."""

 import os
+import tempfile
 import warnings
+import zipfile
 from typing import Tuple, List, Optional, Union, Dict

 import h5py  # type: ignore[import]
 import numpy as np
 from sklearn.preprocessing import MinMaxScaler  # type: ignore[import]

+from rul_datasets import utils
 from rul_datasets.reader.data_root import get_data_root
 from rul_datasets.reader import AbstractReader, scaling


+NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn"
+
+
 class NCmapssReader(AbstractReader):
     """
     This reader provides access to the New C-MAPSS Turbofan Degradation dataset. Each
@@ -206,6 +212,8 @@ def prepare_data(self) -> None:
         data is then split into development and validation set. Afterward, a scaler
         is fit on the development features if it was not already done previously.
         """
+        if not os.path.exists(self._NCMAPSS_ROOT):
+            _download_ncmapss(self._NCMAPSS_ROOT)
         if not os.path.exists(self._get_scaler_path()):
             features, _, _ = self._load_data("dev")
             scaler = scaling.fit_scaler(features, MinMaxScaler())
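Note that _download_ncmapss (added below) calls os.makedirs without exist_ok, so the existence check above is load-bearing: invoking the helper against an existing data root would raise. A quick, self-contained illustration:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        root = os.path.join(tmp, "NCMAPSS")
        os.makedirs(root)        # first call succeeds
        try:
            os.makedirs(root)    # what an unguarded second download would hit
        except FileExistsError as err:
            print(err)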
@@ -352,3 +360,17 @@ def _calc_default_window_size(self):
             max_window_size.append(max(*[len(f) for f in split_features]))

         return max(*max_window_size)
+
+
+def _download_ncmapss(data_root):
+    os.makedirs(data_root)
+    with tempfile.TemporaryDirectory() as tmp_path:
+        print("Download N-C-MAPSS dataset from Google Drive")
+        download_path = os.path.join(tmp_path, "data.zip")
+        utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path)
+        print("Extract N-C-MAPSS dataset")
+        with zipfile.ZipFile(download_path, mode="r") as f:
+            for zipinfo in f.infolist():
+                zipinfo.filename = os.path.basename(zipinfo.filename)
+                if zipinfo.filename:
+                    f.extract(zipinfo, data_root)
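The basename assignment above flattens the archive: files nested in subfolders are extracted directly into data_root, while pure directory entries, whose basename is empty, are skipped by the falsiness check. A quick illustration; the entry names are hypothetical:

    import os

    for name in ("data_set/", "data_set/N-CMAPSS_DS01.h5"):
        print(repr(os.path.basename(name)))
    # '' -> falsy, entry skipped; 'N-CMAPSS_DS01.h5' -> extracted flat into data_root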
27 changes: 24 additions & 3 deletions rul_datasets/utils.py
@@ -7,6 +7,9 @@
 from tqdm import tqdm  # type: ignore


+GDRIVE_URL_BASE = "https://docs.google.com/uc?export=download"
+
+
 def get_files_in_path(path: str, condition: Optional[Callable] = None) -> List[str]:
     """
     Return the paths of all files in a path that satisfy a condition in alphabetical
@@ -88,10 +91,28 @@ def download_file(url: str, save_path: str) -> None:
     response = requests.get(url, stream=True)
     if not response.status_code == 200:
         raise RuntimeError(f"Download failed. Server returned {response.status_code}")
-    content_len = int(response.headers["Content-Length"]) // 1024
+    _write_content(response, save_path)
+
+
+def download_gdrive_file(file_id: str, save_path: str) -> None:
+    session = requests.Session()
+    response = session.get(GDRIVE_URL_BASE, params={"id": file_id}, stream=True)
+    if response.text.startswith("<!DOCTYPE html>"):
+        params = {"id": file_id, "confirm": "t"}
+        response = session.post(GDRIVE_URL_BASE, params=params, stream=True)
+    _write_content(response, save_path)
+
+
+def _write_content(response: requests.Response, save_path: str) -> None:
+    content_len = int(response.headers["Content-Length"])
     with open(save_path, mode="wb") as f:
-        for data in tqdm(response.iter_content(chunk_size=1024), total=content_len):
-            f.write(data)
+        pbar = tqdm(unit="B", unit_scale=True, unit_divisor=1024, total=content_len)
+        pbar.clear()
+        for chunk in response.iter_content(chunk_size=32768):
+            if chunk:
+                pbar.update(len(chunk))
+                f.write(chunk)
+        pbar.close()


 def to_tensor(
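Two details of the new helpers are worth noting. download_gdrive_file works around Google Drive's virus-scan interstitial: for large files the first GET returns an HTML confirmation page rather than the payload, so the function re-requests with confirm=t to proceed with the real download. _write_content then streams the body in 32 KiB chunks behind a tqdm bar whose unit_scale and unit_divisor settings render byte counts human-readably. A minimal usage sketch; the target path is illustrative:

    from rul_datasets import utils
    from rul_datasets.reader.ncmapss import NCMAPSS_DRIVE_ID

    # Streams the archive to disk with a progress bar; a large file's
    # confirmation page triggers the confirm=t fallback transparently.
    utils.download_gdrive_file(NCMAPSS_DRIVE_ID, "/tmp/ncmapss.zip")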
