diff --git a/haystack/utils/import_utils.py b/haystack/utils/import_utils.py index 5038c38bf4..6945fbf3ea 100644 --- a/haystack/utils/import_utils.py +++ b/haystack/utils/import_utils.py @@ -1,14 +1,19 @@ -import logging +import gzip import importlib import importlib.util -from typing import Optional, Tuple, List -from urllib.parse import urlparse, unquote -from os.path import splitext, basename +import io +import logging +import zipfile +from os.path import basename, splitext +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union +from urllib.parse import unquote, urlparse + +import requests from haystack.errors import DatasetsError from haystack.schema import Document - logger = logging.getLogger(__name__) @@ -55,5 +60,54 @@ def get_filename_extension_from_url(url: str) -> Tuple[str, str]: return file_name, archive_extension +def fetch_archive_from_http( + url: str, + output_dir: str, + proxies: Optional[Dict[str, str]] = None, + timeout: Union[float, Tuple[float, float]] = 10.0, +) -> bool: + """ + Fetch an archive (zip or gz) from a url via http and extract content to an output directory. + :param url: http address + :param output_dir: local path + :param proxies: proxies details as required by requests library + :param timeout: How many seconds to wait for the server to send data before giving up, + as a float, or a :ref:`(connect timeout, read timeout) ` tuple. + Defaults to 10 seconds. + :return: if anything got fetched + """ + # verify & prepare local directory + path = Path(output_dir) + if not path.exists(): + path.mkdir(parents=True) + + is_not_empty = len(list(Path(path).rglob("*"))) > 0 + if is_not_empty: + logger.info("Found data stored in '%s'. Delete this first if you really want to fetch new data.", output_dir) + return False + else: + logger.info("Fetching from %s to '%s'", url, output_dir) + + file_name, archive_extension = get_filename_extension_from_url(url) + request_data = requests.get(url, proxies=proxies, timeout=timeout) + + if archive_extension == "zip": + zip_archive = zipfile.ZipFile(io.BytesIO(request_data.content)) + zip_archive.extractall(output_dir) + elif archive_extension == "gz" and not "tar.gz" in url: + gzip_archive = gzip.GzipFile(fileobj=io.BytesIO(request_data.content)) + file_content = gzip_archive.read() + with open(f"{output_dir}/{file_name}", "wb") as file: + file.write(file_content) + else: + logger.warning( + "Skipped url %s as file type is not supported here. " + "See haystack documentation for support of more file types", + url, + ) + + return True + + def is_whisper_available(): return importlib.util.find_spec("whisper") is not None diff --git a/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml b/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml new file mode 100644 index 0000000000..921e88dbb5 --- /dev/null +++ b/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add previously removed `fetch_archive_from_http` util function to fetch zip and gzip archives from url