diff --git a/rul_datasets/__init__.py b/rul_datasets/__init__.py index 4781997..e1b4c02 100644 --- a/rul_datasets/__init__.py +++ b/rul_datasets/__init__.py @@ -9,7 +9,7 @@ ) from .baseline import BaselineDataModule, PretrainingBaselineDataModule from .core import RulDataModule -from .reader import CmapssReader, FemtoReader, XjtuSyReader +from .reader import CmapssReader, FemtoReader, XjtuSyReader, NCmapssReader from .reader.data_root import get_data_root, set_data_root from .ssl import SemiSupervisedDataModule diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py index 8a260c5..1b98f6f 100644 --- a/rul_datasets/reader/ncmapss.py +++ b/rul_datasets/reader/ncmapss.py @@ -363,12 +363,12 @@ def _calc_default_window_size(self): def _download_ncmapss(data_root): - os.makedirs(data_root) with tempfile.TemporaryDirectory() as tmp_path: print("Download N-C-MAPSS dataset from Google Drive") download_path = os.path.join(tmp_path, "data.zip") utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path) print("Extract N-C-MAPSS dataset") + os.makedirs(data_root) with zipfile.ZipFile(download_path, mode="r") as f: for zipinfo in f.infolist(): zipinfo.filename = os.path.basename(zipinfo.filename) diff --git a/rul_datasets/utils.py b/rul_datasets/utils.py index faee551..d88f5ae 100644 --- a/rul_datasets/utils.py +++ b/rul_datasets/utils.py @@ -100,6 +100,14 @@ def download_gdrive_file(file_id: str, save_path: str) -> None: if response.text.startswith(""): params = {"id": file_id, "confirm": "t"} response = session.post(GDRIVE_URL_BASE, params=params, stream=True) + if response.status_code == 429: + raise RuntimeError( + "Download failed. Server returned 429. " + "This is usually caused by too many requests. " + "Please try again later." + ) + elif not response.status_code == 200: + raise RuntimeError(f"Download failed. Server returned {response.status_code}") _write_content(response, save_path) @@ -112,6 +120,7 @@ def _write_content(response: requests.Response, save_path: str) -> None: if chunk: pbar.update(len(chunk)) f.write(chunk) + f.flush() pbar.close()