Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make downloading from GDrive more robust #53

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rul_datasets/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from .baseline import BaselineDataModule, PretrainingBaselineDataModule
from .core import RulDataModule
from .reader import CmapssReader, FemtoReader, XjtuSyReader
from .reader import CmapssReader, FemtoReader, XjtuSyReader, NCmapssReader
from .reader.data_root import get_data_root, set_data_root
from .ssl import SemiSupervisedDataModule

Expand Down
2 changes: 1 addition & 1 deletion rul_datasets/reader/ncmapss.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -363,12 +363,12 @@ def _calc_default_window_size(self):


def _download_ncmapss(data_root):
os.makedirs(data_root)
with tempfile.TemporaryDirectory() as tmp_path:
print("Download N-C-MAPSS dataset from Google Drive")
download_path = os.path.join(tmp_path, "data.zip")
utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path)
print("Extract N-C-MAPSS dataset")
os.makedirs(data_root)
with zipfile.ZipFile(download_path, mode="r") as f:
for zipinfo in f.infolist():
zipinfo.filename = os.path.basename(zipinfo.filename)
Expand Down
9 changes: 9 additions & 0 deletions rul_datasets/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -100,6 +100,14 @@ def download_gdrive_file(file_id: str, save_path: str) -> None:
if response.text.startswith("<!DOCTYPE html>"):
params = {"id": file_id, "confirm": "t"}
response = session.post(GDRIVE_URL_BASE, params=params, stream=True)
if response.status_code == 429:
raise RuntimeError(
"Download failed. Server returned 429. "
"This is usually caused by too many requests. "
"Please try again later."
)
elif not response.status_code == 200:
raise RuntimeError(f"Download failed. Server returned {response.status_code}")
_write_content(response, save_path)


Expand All @@ -112,6 +120,7 @@ def _write_content(response: requests.Response, save_path: str) -> None:
if chunk:
pbar.update(len(chunk))
f.write(chunk)
f.flush()
pbar.close()


Expand Down
Loading