diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc892679..9e043d21 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,14 @@ and this project uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
 
 ## [Unreleased]
 
+### Fixed
+
+- Fix `earthaccess.download` to not ignore errors by default
+  ([#581](https://github.com/nsidc/earthaccess/issues/581))
+  ([**@Sherwin-14**](https://github.com/Sherwin-14),
+  [**@chuckwondo**](https://github.com/chuckwondo),
+  [**@mfisher87**](https://github.com/mfisher87))
+
 ### Changed
 
 - Use built-in `assert` statements instead of `unittest` assertions in
diff --git a/earthaccess/api.py b/earthaccess/api.py
index 0a50e563..5ad75ed5 100644
--- a/earthaccess/api.py
+++ b/earthaccess/api.py
@@ -3,7 +3,7 @@
 import requests
 import s3fs
 from fsspec import AbstractFileSystem
-from typing_extensions import Any, Dict, List, Optional, Union, deprecated
+from typing_extensions import Any, Dict, List, Mapping, Optional, Union, deprecated
 
 import earthaccess
 from earthaccess.services import DataServices
@@ -205,6 +205,7 @@ def download(
     local_path: Optional[str],
     provider: Optional[str] = None,
     threads: int = 8,
+    pqdm_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> List[str]:
     """Retrieves data granules from a remote storage system.
@@ -217,6 +218,9 @@ def download(
         local_path: local directory to store the remote data granules
         provider: if we download a list of URLs, we need to specify the provider.
         threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
+        pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
+            See pqdm documentation for available options. Default is to use immediate exception behavior
+            and the number of jobs specified by the `threads` parameter.
 
     Returns:
         List of downloaded files
@@ -225,12 +229,19 @@ def download(
     Raises:
         Exception: A file download failed.
     """
     provider = _normalize_location(provider)
+    pqdm_kwargs = {
+        "exception_behaviour": "immediate",
+        "n_jobs": threads,
+        **(pqdm_kwargs or {}),
+    }
     if isinstance(granules, DataGranule):
         granules = [granules]
     elif isinstance(granules, str):
         granules = [granules]
     try:
-        results = earthaccess.__store__.get(granules, local_path, provider, threads)
+        results = earthaccess.__store__.get(
+            granules, local_path, provider, threads, pqdm_kwargs
+        )
     except AttributeError as err:
         logger.error(
             f"{err}: You must call earthaccess.login() before you can download data"
@@ -242,6 +253,7 @@
 def open(
     granules: Union[List[str], List[DataGranule]],
     provider: Optional[str] = None,
+    pqdm_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> List[AbstractFileSystem]:
     """Returns a list of file-like objects that can be used to access files
     hosted on S3 or HTTPS by third party libraries like xarray.
@@ -250,12 +262,21 @@ def open(
         granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
             If a list of URLs is passed, we need to specify the data provider.
         provider: e.g. POCLOUD, NSIDC_CPRD, etc.
+        pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
+            See pqdm documentation for available options. Default is to use immediate exception
+            behavior; pass `n_jobs` here to control parallelism, since `open` has no `threads` parameter.
 
     Returns:
         A list of "file pointers" to remote (i.e. s3 or https) files.
""" provider = _normalize_location(provider) - results = earthaccess.__store__.open(granules=granules, provider=provider) + pqdm_kwargs = { + "exception_behavior": "immediate", + **(pqdm_kwargs or {}), + } + results = earthaccess.__store__.open( + granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs + ) return results diff --git a/earthaccess/store.py b/earthaccess/store.py index fd399d44..f7b5c85e 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -64,12 +64,20 @@ def _open_files( url_mapping: Mapping[str, Union[DataGranule, None]], fs: fsspec.AbstractFileSystem, threads: int = 8, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.spec.AbstractBufferedFile]: def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile: url, granule = data return EarthAccessFile(fs.open(url), granule) # type: ignore - fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads) + pqdm_kwargs = { + "exception_behavior": "immediate", + **(pqdm_kwargs or {}), + } + + fileset = pqdm( + url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs + ) return fileset @@ -336,6 +344,7 @@ def open( self, granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.spec.AbstractBufferedFile]: """Returns a list of file-like objects that can be used to access files hosted on S3 or HTTPS by third party libraries like xarray. @@ -344,12 +353,15 @@ def open( granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`. If a list of URLs is passed, we need to specify the data provider. provider: e.g. POCLOUD, NSIDC_CPRD, etc. + pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. + See pqdm documentation for available options. Default is to use immediate exception behavior + and the number of jobs specified by the `threads` parameter. Returns: A list of "file pointers" to remote (i.e. s3 or https) files. """ if len(granules): - return self._open(granules, provider) + return self._open(granules, provider, pqdm_kwargs) return [] @singledispatchmethod @@ -357,6 +369,7 @@ def _open( self, granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: raise NotImplementedError("granules should be a list of DataGranule or URLs") @@ -420,6 +433,7 @@ def _open_urls( granules: List[str], provider: Optional[str] = None, threads: int = 8, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: fileset: List = [] @@ -447,6 +461,7 @@ def _open_urls( url_mapping, fs=s3_fs, threads=threads, + pqdm_kwargs=pqdm_kwargs, ) except Exception as e: raise RuntimeError( @@ -466,7 +481,7 @@ def _open_urls( raise ValueError( "We cannot open S3 links when we are not in-region, try using HTTPS links" ) - fileset = self._open_urls_https(url_mapping, threads) + fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs) return fileset def get( @@ -475,6 +490,7 @@ def get( local_path: Union[Path, str, None] = None, provider: Optional[str] = None, threads: int = 8, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: """Retrieves data granules from a remote storage system. @@ -491,6 +507,9 @@ def get( provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8. 
+            pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
+                See pqdm documentation for available options. Default is to use immediate exception behavior
+                and the number of jobs specified by the `threads` parameter.
 
         Returns:
             List of downloaded files
@@ -503,7 +522,7 @@ def get(
             local_path = Path(local_path)
 
         if len(granules):
-            files = self._get(granules, local_path, provider, threads)
+            files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
             return files
         else:
             raise ValueError("List of URLs or DataGranule instances expected")
 
     @singledispatchmethod
     def _get(
@@ -515,6 +534,7 @@ def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
         local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         """Retrieves data granules from a remote storage system.
@@ -531,6 +551,9 @@ def _get(
             provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
             threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8.
+            pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
+                See pqdm documentation for available options. Default is to use immediate exception behavior
+                and the number of jobs specified by the `threads` parameter.
 
         Returns:
             None
@@ -544,6 +567,7 @@ def _get_urls(
         granules: List[str],
         local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         data_links = granules
         downloaded_files: List = []
@@ -565,7 +589,9 @@ def _get_urls(
         else:
             # if we are not in AWS
-            return self._download_onprem_granules(data_links, local_path, threads)
+            return self._download_onprem_granules(
+                data_links, local_path, threads, pqdm_kwargs
+            )
 
     @_get.register
     def _get_granules(
@@ -574,6 +600,7 @@ def _get_granules(
         self,
         granules: List[DataGranule],
         local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         data_links: List = []
         downloaded_files: List = []
@@ -614,7 +641,9 @@ def _get_granules(
         else:
             # if the data are cloud-based, but we are not in AWS,
             # it will be downloaded as if it was on prem
-            return self._download_onprem_granules(data_links, local_path, threads)
+            return self._download_onprem_granules(
+                data_links, local_path, threads, pqdm_kwargs
+            )
 
     def _download_file(self, url: str, directory: Path) -> str:
         """Download a single file from an on-prem location, a DAAC data center.
@@ -652,7 +681,11 @@ def _download_file(
         return str(path)
 
     def _download_onprem_granules(
-        self, urls: List[str], directory: Path, threads: int = 8
+        self,
+        urls: List[str],
+        directory: Path,
+        threads: int = 8,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.
@@ -661,6 +694,9 @@ def _download_onprem_granules(
             urls: list of granule URLs from a remote location (on-prem or cloud)
             directory: local directory to store the downloaded files
             threads: parallel number of threads to use to download the files; adjust as necessary, default = 8
+            pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
+                See pqdm documentation for available options. Default is to use immediate exception behavior
+                and the number of jobs specified by the `threads` parameter.
 
         Returns:
             A list of local filepaths to which the files were downloaded.
@@ -674,11 +710,18 @@ def _download_onprem_granules(
         directory.mkdir(parents=True, exist_ok=True)
 
         arguments = [(url, directory) for url in urls]
+
+        pqdm_kwargs = {
+            "exception_behaviour": "immediate",
+            "n_jobs": threads,
+            **(pqdm_kwargs or {}),
+        }
+
         results = pqdm(
             arguments,
             self._download_file,
-            n_jobs=threads,
             argument_type="args",
+            **pqdm_kwargs,
         )
         return results
@@ -686,11 +729,12 @@ def _open_urls_https(
         self,
         url_mapping: Mapping[str, Union[DataGranule, None]],
         threads: int = 8,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[fsspec.AbstractFileSystem]:
         https_fs = self.get_fsspec_session()
         try:
-            return _open_files(url_mapping, https_fs, threads)
+            return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
         except Exception:
             logger.exception(
                 "An exception occurred while trying to access remote files via HTTPS"
diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py
new file mode 100644
index 00000000..20980e35
--- /dev/null
+++ b/tests/unit/test_api.py
@@ -0,0 +1,50 @@
+from unittest.mock import Mock
+
+import earthaccess
+import pytest
+
+
+def test_download_immediate_failure(monkeypatch):
+    earthaccess.login()
+
+    results = earthaccess.search_data(
+        short_name="ATL06",
+        bounding_box=(-10, 20, 10, 50),
+        temporal=("1999-02", "2019-03"),
+        count=10,
+    )
+
+    def mock_get(*args, **kwargs):
+        raise Exception("Download failed")
+
+    mock_store = Mock()
+    monkeypatch.setattr(earthaccess, "__store__", mock_store)
+    monkeypatch.setattr(mock_store, "get", mock_get)
+
+    with pytest.raises(Exception, match="Download failed"):
+        earthaccess.download(results, "/home/download-folder")
+
+
+def test_download_deferred_failure(monkeypatch):
+    earthaccess.login()
+
+    results = earthaccess.search_data(
+        short_name="ATL06",
+        bounding_box=(-10, 20, 10, 50),
+        temporal=("1999-02", "2019-03"),
+        count=10,
+    )
+
+    def mock_get(*args, **kwargs):
+        return [Exception("Download failed")] * len(results)
+
+    mock_store = Mock()
+    monkeypatch.setattr(earthaccess, "__store__", mock_store)
+    monkeypatch.setattr(mock_store, "get", mock_get)
+
+    errors = earthaccess.download(
+        results, "/home/download-folder", None, 8, {"exception_behaviour": "deferred"}
+    )
+
+    assert all(isinstance(e, Exception) for e in errors)
+    assert len(errors) == len(results)
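
Usage note: with this patch, `earthaccess.download` fails fast by default, and callers can opt back into error collection through `pqdm_kwargs`. A minimal sketch of both modes, assuming a logged-in session with Earthdata credentials and network access; the short name, count, and `./data` directory are illustrative values, not part of this diff:

```python
import earthaccess

earthaccess.login()
granules = earthaccess.search_data(short_name="ATL06", count=4)

# Default behaviour: pqdm raises on the first failed download ("immediate").
files = earthaccess.download(granules, "./data")

# Opt-in: defer exceptions so every download is attempted; failures come
# back as Exception instances mixed into the returned list.
results = earthaccess.download(
    granules,
    "./data",
    pqdm_kwargs={"exception_behaviour": "deferred"},
)
errors = [r for r in results if isinstance(r, Exception)]
```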
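Design note: the defaulting pattern used in `_open_files` and `_download_onprem_granules` relies on dict-unpacking order. Keys listed first act as defaults, the caller's mapping is unpacked last so its values win, and `or {}` guards against a `None` mapping. A self-contained sketch of that precedence; the helper name `merge_pqdm_kwargs` is hypothetical, not part of earthaccess:

```python
from typing import Any, Mapping, Optional


def merge_pqdm_kwargs(
    threads: int, pqdm_kwargs: Optional[Mapping[str, Any]] = None
) -> dict:
    # Later keys win: caller-supplied values override the defaults,
    # and `or {}` keeps `**None` from raising a TypeError.
    return {
        "exception_behaviour": "immediate",
        "n_jobs": threads,
        **(pqdm_kwargs or {}),
    }


assert merge_pqdm_kwargs(8) == {"exception_behaviour": "immediate", "n_jobs": 8}
assert merge_pqdm_kwargs(8, {"exception_behaviour": "deferred", "n_jobs": 2}) == {
    "exception_behaviour": "deferred",
    "n_jobs": 2,
}
```

Merging `n_jobs` into the mapping before the `pqdm(...)` call, rather than passing `n_jobs=threads` alongside `**pqdm_kwargs`, also avoids a `TypeError` for duplicate keyword arguments when a caller supplies their own `n_jobs`.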