From dcec90d03a2c4d331a07c5a05077f1aad2effed8 Mon Sep 17 00:00:00 2001
From: Tilman Krokotsch
Date: Fri, 12 Jan 2024 17:52:04 +0100
Subject: [PATCH 1/2] feat: add function to download from gdrive

---
 rul_datasets/utils.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/rul_datasets/utils.py b/rul_datasets/utils.py
index 9ed6ca3..faee551 100644
--- a/rul_datasets/utils.py
+++ b/rul_datasets/utils.py
@@ -7,6 +7,9 @@
 from tqdm import tqdm  # type: ignore
 
 
+GDRIVE_URL_BASE = "https://docs.google.com/uc?export=download"
+
+
 def get_files_in_path(path: str, condition: Optional[Callable] = None) -> List[str]:
     """
     Return the paths of all files in a path that satisfy a condition in alphabetical
@@ -88,10 +91,28 @@ def download_file(url: str, save_path: str) -> None:
     response = requests.get(url, stream=True)
     if not response.status_code == 200:
         raise RuntimeError(f"Download failed. Server returned {response.status_code}")
-    content_len = int(response.headers["Content-Length"]) // 1024
+    _write_content(response, save_path)
+
+
+def download_gdrive_file(file_id: str, save_path: str) -> None:
+    session = requests.Session()
+    response = session.get(GDRIVE_URL_BASE, params={"id": file_id}, stream=True)
+    if response.text.startswith("<!DOCTYPE html>"):
+        params = {"id": file_id, "confirm": "t"}
+        response = session.post(GDRIVE_URL_BASE, params=params, stream=True)
+    _write_content(response, save_path)
+
+
+def _write_content(response: requests.Response, save_path: str) -> None:
+    content_len = int(response.headers["Content-Length"])
     with open(save_path, mode="wb") as f:
-        for data in tqdm(response.iter_content(chunk_size=1024), total=content_len):
-            f.write(data)
+        pbar = tqdm(unit="B", unit_scale=True, unit_divisor=1024, total=content_len)
+        pbar.clear()
+        for chunk in response.iter_content(chunk_size=32768):
+            if chunk:
+                pbar.update(len(chunk))
+                f.write(chunk)
+        pbar.close()
 
 
 def to_tensor(
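For reference, the helper added in this patch fetches any publicly shared
Google Drive file by its ID. Below is a minimal usage sketch, assuming the
patch is applied; "some-file-id" and "/tmp/data.zip" are placeholder values,
not real identifiers:

    from rul_datasets.utils import download_gdrive_file

    # Google Drive serves large files behind an HTML virus-scan warning page
    # instead of the raw content. The helper detects that page on the first
    # GET and retries with confirm=t, so the call works for files of any size.
    download_gdrive_file("some-file-id", "/tmp/data.zip")

The refactor into _write_content also lets download_file and
download_gdrive_file share one streaming write path with a byte-accurate,
unit-scaled progress bar.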
""" + if not os.path.exists(self._NCMAPSS_ROOT): + _download_ncmapss(self._NCMAPSS_ROOT) if not os.path.exists(self._get_scaler_path()): features, _, _ = self._load_data("dev") scaler = scaling.fit_scaler(features, MinMaxScaler()) @@ -352,3 +360,17 @@ def _calc_default_window_size(self): max_window_size.append(max(*[len(f) for f in split_features])) return max(*max_window_size) + + +def _download_ncmapss(data_root): + os.makedirs(data_root) + with tempfile.TemporaryDirectory() as tmp_path: + print("Download N-C-MAPSS dataset from Google Drive") + download_path = os.path.join(tmp_path, "data.zip") + utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path) + print("Extract N-C-MAPSS dataset") + with zipfile.ZipFile(download_path, mode="r") as f: + for zipinfo in f.infolist(): + zipinfo.filename = os.path.basename(zipinfo.filename) + if zipinfo.filename: + f.extract(zipinfo, data_root)