diff --git a/kghub_downloader/download_utils.py b/kghub_downloader/download_utils.py
index 3d5196c..14578e2 100644
--- a/kghub_downloader/download_utils.py
+++ b/kghub_downloader/download_utils.py
@@ -55,184 +55,187 @@ def download_from_yaml(
     pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
     with open(yaml_file) as f:
         data = yaml.load(f, Loader=yaml.FullLoader)
-        # Limit to only tagged downloads, if tags are passed in
-        if tags:
-            data = [
-                item
-                for item in data
-                if "tag" in item and item["tag"] and item["tag"] in tags
-            ]
-
-        for item in tqdm(data, desc="Downloading files"):
-            if "url" not in item:
-                logging.error("Couldn't find url for source in {}".format(item))
-                continue
-            if snippet_only and (item["local_name"])[-3:] in [
-                "zip",
-                ".gz",
-            ]:  # Can't truncate compressed files
-                logging.error(
-                    "Asked to download snippets; can't snippet {}".format(item)
-                )
-                continue
-            local_name = (
-                item["local_name"]
-                if "local_name" in item and item["local_name"]
-                else item["url"].split("/")[-1]
-            )
-            outfile = os.path.join(output_dir, local_name)
-
-            logging.info("Retrieving %s from %s" % (outfile, item["url"]))
-
-            if "local_name" in item:
-                local_file_dir = os.path.join(
-                    output_dir, os.path.dirname(item["local_name"])
-                )
-                if not os.path.exists(local_file_dir):
-                    logging.info(f"Creating local directory {local_file_dir}")
-                    pathlib.Path(local_file_dir).mkdir(parents=True, exist_ok=True)
-
-            if os.path.exists(outfile):
-                if ignore_cache:
-                    logging.info("Deleting cached version of {}".format(outfile))
-                    os.remove(outfile)
-                else:
-                    logging.info("Using cached version of {}".format(outfile))
-                    continue
-
-            # Download file
-            if "api" in item:
-                download_from_api(item, outfile)
-            if "url" in item:
-                url = parse_url(item["url"])
-                if url.startswith("gs://"):
-                    Blob.from_string(url, client=storage.Client()).download_to_filename(
-                        outfile
-                    )
-                elif url.startswith("s3://"):
-                    s3 = boto3.client("s3")
-                    bucket_name = url.split("/")[2]
-                    remote_file = "/".join(url.split("/")[3:])
-                    s3.download_file(bucket_name, remote_file, outfile)
-                elif url.startswith("ftp"):
-                    glob = None
-                    if "glob" in item:
-                        glob = item["glob"]
-                    ftp_username = (
-                        os.getenv("FTP_USERNAME") if os.getenv("FTP_USERNAME") else None
-                    )
-                    ftp_password = (
-                        os.getenv("FTP_PASSWORD") if os.getenv("FTP_PASSWORD") else None
-                    )
-                    host = url.split("/")[0]
-                    path = "/".join(url.split("/")[1:])
-                    ftp = ftplib.FTP(host)
-                    ftp.login(ftp_username, ftp_password)
-                    download_via_ftp(ftp, path, outfile, glob)
-                elif any(
-                    url.startswith(str(i))
-                    for i in list(GDOWN_MAP.keys()) + list(GDOWN_MAP.values())
-                ):
-                    # Check if url starts with a key or a value
-                    for key, value in GDOWN_MAP.items():
-                        if url.startswith(str(value)):
-                            # If value, then download the file directly
-                            gdown.download(url, output=outfile)
-                            break
-                        elif url.startswith(str(key)):
-                            # If key, replace key by value and then download
-                            new_url = url.replace(str(key) + ":", str(value))
-                            gdown.download(new_url, output=outfile)
-                            break
-                    else:
-                        # If the loop completes without breaking (i.e., no match found), throw an error
-                        raise ValueError("Invalid URL")
-                elif url.startswith("git://"):
-                    url_split = url.split("/")
-                    repo_owner = url_split[-3]
-                    repo_name = url_split[-2]
-                    asset_name = url_split[-1]
-                    asset_url = None
-                    api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
-                    # Get the list of releases
-                    response = requests.get(api_url)
-                    response.raise_for_status()
-                    releases = response.json()
-
-                    if not releases:
-                        print("No releases found for this repository.")
-                        sys.exit(1)
-
-                    # Check if a specific tag is provided
-                    if "tag" in item:
-                        # Find the release with the specified tag
-                        tagged_release = next(
-                            (
-                                release
-                                for release in releases
-                                if release["tag_name"] == item["tag"]
-                            ),
-                            None,
-                        )
-                        if tagged_release:
-                            for asset in tagged_release.get("assets", []):
-                                if asset["name"] == asset_name:
-                                    asset_url = asset["browser_download_url"]
-                                    break
-
-                    # If no asset found in the specified tag or no tag provided, check other releases
-                    if not asset_url:
-                        for release in releases:
-                            for asset in release.get("assets", []):
-                                if asset["name"] == asset_name:
-                                    asset_url = asset["browser_download_url"]
-                                    break
-                            if asset_url:
-                                break
-
-                    if not asset_url:
-                        print(f"Asset '{asset_name}' not found in any release.")
-                        sys.exit(1)
-
-                    # Download the asset
-                    response = requests.get(asset_url, stream=True)
-                    response.raise_for_status()
-                    with open(outfile, "wb") as file:
-                        for chunk in response.iter_content(chunk_size=8192):
-                            file.write(chunk)
-                    print(f"Downloaded {asset_name}")
-
-                else:
-                    req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
-                    try:
-                        with urlopen(req) as response, open(outfile, "wb") as out_file:  # type: ignore
-                            if snippet_only:
-                                data = response.read(
-                                    5120
-                                )  # first 5 kB of a `bytes` object
-                            else:
-                                data = response.read()  # a `bytes` object
-                            out_file.write(data)
-                        if snippet_only:  # Need to clean up the outfile
-                            in_file = open(outfile, "r+")
-                            in_lines = in_file.read()
-                            in_file.close()
-                            splitlines = in_lines.split("\n")
-                            outstring = "\n".join(splitlines[:-1])
-                            cleanfile = open(outfile, "w+")
-                            for i in range(len(outstring)):
-                                cleanfile.write(outstring[i])
-                            cleanfile.close()
-                    except URLError:
-                        logging.error(f"Failed to download: {url}")
-                        raise
-
-            # If mirror, upload to remote storage
-            if mirror:
-                mirror_to_bucket(
-                    local_file=outfile, bucket_url=mirror, remote_file=local_name
-                )
+    # Limit to only tagged downloads, if tags are passed in
+    if tags:
+        data = [
+            item
+            for item in data
+            if "tag" in item and item["tag"] and item["tag"] in tags
+        ]
+
+    for item in tqdm(data, desc="Downloading files"):
+        if "url" not in item:
+            logging.error("Couldn't find url for source in {}".format(item))
+            continue
+        if snippet_only and (item["local_name"])[-3:] in [
+            "zip",
+            ".gz",
+        ]:  # Can't truncate compressed files
+            logging.error(
+                "Asked to download snippets; can't snippet {}".format(item)
+            )
+            continue
+
+        local_name = (
+            item["local_name"]
+            if "local_name" in item and item["local_name"]
+            else item["url"].split("/")[-1]
+        )
+        outfile = os.path.join(output_dir, local_name)
+
+        logging.info("Retrieving %s from %s" % (outfile, item["url"]))
+
+        if "local_name" in item:
+            local_file_dir = os.path.join(
+                output_dir, os.path.dirname(item["local_name"])
+            )
+            if not os.path.exists(local_file_dir):
+                logging.info(f"Creating local directory {local_file_dir}")
+                pathlib.Path(local_file_dir).mkdir(parents=True, exist_ok=True)
+
+        if os.path.exists(outfile):
+            if ignore_cache:
+                logging.info("Deleting cached version of {}".format(outfile))
+                os.remove(outfile)
+            else:
+                logging.info("Using cached version of {}".format(outfile))
+                continue
+
+        # Download file
+        if "api" in item:
+            download_from_api(item, outfile)
+        if "url" in item:
+            url = parse_url(item["url"])
+            if url.startswith("gs://"):
+                Blob.from_string(url, client=storage.Client()).download_to_filename(
+                    outfile
+                )
+            elif url.startswith("s3://"):
+                s3 = boto3.client("s3")
+                bucket_name = url.split("/")[2]
+                remote_file = "/".join(url.split("/")[3:])
+                s3.download_file(bucket_name, remote_file, outfile)
+            elif url.startswith("ftp"):
+                glob = None
+                if "glob" in item:
+                    glob = item["glob"]
+                ftp_username = (
+                    os.getenv("FTP_USERNAME") if os.getenv("FTP_USERNAME") else None
+                )
+                ftp_password = (
+                    os.getenv("FTP_PASSWORD") if os.getenv("FTP_PASSWORD") else None
+                )
+                host = url.split("/")[0]
+                path = "/".join(url.split("/")[1:])
+                ftp = ftplib.FTP(host)
+                ftp.login(ftp_username, ftp_password)
+                download_via_ftp(ftp, path, outfile, glob)
+            elif any(
+                url.startswith(str(i))
+                for i in list(GDOWN_MAP.keys()) + list(GDOWN_MAP.values())
+            ):
+                # Check if url starts with a key or a value
+                for key, value in GDOWN_MAP.items():
+                    if url.startswith(str(value)):
+                        # If value, then download the file directly
+                        gdown.download(url, output=outfile)
+                        break
+                    elif url.startswith(str(key)):
+                        # If key, replace key by value and then download
+                        new_url = url.replace(str(key) + ":", str(value))
+                        gdown.download(new_url, output=outfile)
+                        break
+                else:
+                    # If the loop completes without breaking (i.e., no match found), throw an error
+                    raise ValueError("Invalid URL")
+            elif url.startswith("git://"):
+                url_split = url.split("/")
+                repo_owner = url_split[-3]
+                repo_name = url_split[-2]
+                asset_name = url_split[-1]
+                asset_url = None
+                api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
+                # Get the list of releases
+                response = requests.get(api_url)
+                response.raise_for_status()
+                releases = response.json()
+
+                if not releases:
+                    print("No releases found for this repository.")
+                    sys.exit(1)
+
+                # Check if a specific tag is provided
+                if "tag" in item:
+                    # Find the release with the specified tag
+                    tagged_release = next(
+                        (
+                            release
+                            for release in releases
+                            if release["tag_name"] == item["tag"]
+                        ),
+                        None,
+                    )
+                    if tagged_release:
+                        for asset in tagged_release.get("assets", []):
+                            if asset["name"] == asset_name:
+                                asset_url = asset["browser_download_url"]
+                                break
+
+                # If no asset found in the specified tag or no tag provided, check other releases
+                if not asset_url:
+                    for release in releases:
+                        for asset in release.get("assets", []):
+                            if asset["name"] == asset_name:
+                                asset_url = asset["browser_download_url"]
+                                break
+                        if asset_url:
+                            break
+
+                if not asset_url:
+                    print(f"Asset '{asset_name}' not found in any release.")
+                    sys.exit(1)
+
+                # Download the asset
+                response = requests.get(asset_url, stream=True)
+                response.raise_for_status()
+                with open(outfile, "wb") as file:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        file.write(chunk)
+                print(f"Downloaded {asset_name}")
+
+            else:
+                req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
+                try:
+                    with urlopen(req) as response:  # type: ignore
+                        if snippet_only:
+                            data = response.read(
+                                5120
+                            )  # first 5 kB of a `bytes` object
+                        else:
+                            data = response.read()  # a `bytes` object
+
+                    with open(outfile, "wb") as out_file:
+                        out_file.write(data)
+                    if snippet_only:  # Need to clean up the outfile
+                        in_file = open(outfile, "r+")
+                        in_lines = in_file.read()
+                        in_file.close()
+                        splitlines = in_lines.split("\n")
+                        outstring = "\n".join(splitlines[:-1])
+                        cleanfile = open(outfile, "w+")
+                        for i in range(len(outstring)):
+                            cleanfile.write(outstring[i])
+                        cleanfile.close()
+                except URLError:
+                    logging.error(f"Failed to download: {url}")
+                    raise
+
+        # If mirror, upload to remote storage
+        if mirror:
+            mirror_to_bucket(
+                local_file=outfile, bucket_url=mirror, remote_file=local_name
+            )

     return None
@@ -309,11 +312,11 @@ def download_from_api(yaml_item, outfile) -> None:
         query_data = compress_json.local_load(
             os.path.join(os.getcwd(), yaml_item["query_file"])
         )
-        output = open(outfile, "w")
         records = elastic_search_query(
             es_conn, index=yaml_item["index"], query=query_data
         )
-        json.dump(records, output)
+        with open(outfile, "w") as output:
+            json.dump(records, output)
         return None
     else:
        raise RuntimeError(f"API {yaml_item['api']} not supported")
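A few notes on the change, with sketches of the patterns involved. Most of the first hunk is a four-space dedent: everything after `yaml.load` moves out of the `with open(yaml_file)` block, since the file handle is only needed while the YAML is parsed and the download loop never touches `f`. The tag filter itself is unchanged; in isolation it keeps an item only if it carries a non-empty `tag` that is in the requested set (toy data, field names as read by the loop):

```python
items = [
    {"url": "https://example.org/a.tsv", "tag": "core"},
    {"url": "https://example.org/b.tsv"},             # no tag -> dropped
    {"url": "https://example.org/c.tsv", "tag": ""},  # empty tag -> dropped
]
tags = ["core"]
kept = [
    item for item in items
    if "tag" in item and item["tag"] and item["tag"] in tags
]
assert [item["url"] for item in kept] == ["https://example.org/a.tsv"]
```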
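The substantive fix in the first hunk is splitting the combined `with urlopen(req) as response, open(outfile, "wb") as out_file:`. Opening the output file before the body has been read means a request that fails part-way leaves an empty or truncated `outfile` on disk, and the next run's `os.path.exists(outfile)` check then logs "Using cached version" and skips it. A minimal sketch of the pattern, with illustrative names:

```python
from urllib.request import Request, urlopen

def fetch(url: str, outfile: str) -> None:
    """Sketch of the split-context-manager pattern from the diff."""
    req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    # Read the whole body first; if this raises, `outfile` is never created,
    # so a later run cannot mistake a partial download for a cached file.
    with urlopen(req) as response:
        data = response.read()
    # Open (and create) the output file only after the payload is in hand.
    with open(outfile, "wb") as out_file:
        out_file.write(data)
```

The split costs nothing here, since the code already buffered the whole payload with `response.read()`; truly large files would want streaming to a temporary file plus an atomic rename, but that is beyond this change.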
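In snippet mode the code keeps only the first 5120 bytes and then drops the final line, since a byte cutoff almost always lands mid-record. The read-modify-write dance is carried over unchanged; note that the `for i in range(len(outstring))` loop writes one character per iteration because `outstring` is a string, so it is equivalent to a single `write`. A condensed version for comparison (a sketch with a hypothetical name, not part of the diff):

```python
def truncate_partial_last_line(path: str) -> None:
    # Drop everything after the last newline; a fixed-byte cutoff almost
    # always lands mid-line, so the final line is assumed to be partial.
    with open(path) as fh:
        text = fh.read()
    with open(path, "w") as fh:
        fh.write("\n".join(text.split("\n")[:-1]))
```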
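The Google Drive branch accepts either a `GDOWN_MAP` shorthand key or an already-expanded value URL, and leans on Python's `for ... else`: the `else` arm runs only when the loop finishes without hitting `break`. The same logic as a pure function, with a hypothetical one-entry map (the real `GDOWN_MAP` lives elsewhere in the module):

```python
GDOWN_MAP = {"gdrive": "https://drive.google.com/uc?id="}  # hypothetical contents

def resolve_gdown_url(url: str) -> str:
    for key, value in GDOWN_MAP.items():
        if url.startswith(str(value)):
            return url  # already a direct URL, download as-is
        elif url.startswith(str(key)):
            # e.g. "gdrive:FILE_ID" -> "https://drive.google.com/uc?id=FILE_ID"
            return url.replace(str(key) + ":", str(value))
    # Mirrors the for/else in the diff: no prefix matched.
    raise ValueError("Invalid URL")

assert resolve_gdown_url("gdrive:abc123") == "https://drive.google.com/uc?id=abc123"
```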
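The `git://` branch resolves a URL of the form `git://<owner>/<repo>/<asset_name>` against the GitHub releases API in two passes: assets of the release matching `item["tag"]` first, then any release. Condensed into a function (a sketch; `releases` stands for the parsed JSON list the diff fetches from `api_url`):

```python
from typing import Optional

def find_asset_url(
    releases: list, asset_name: str, tag: Optional[str] = None
) -> Optional[str]:
    # Pass 1: only the release carrying the requested tag, if one was given.
    if tag is not None:
        for release in releases:
            if release["tag_name"] == tag:
                for asset in release.get("assets", []):
                    if asset["name"] == asset_name:
                        return asset["browser_download_url"]
    # Pass 2: fall back to scanning every release in API order.
    for release in releases:
        for asset in release.get("assets", []):
            if asset["name"] == asset_name:
                return asset["browser_download_url"]
    return None
```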
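A nit carried over unchanged in the FTP branch: `os.getenv("FTP_USERNAME") if os.getenv("FTP_USERNAME") else None` is a no-op wrapper, since `os.getenv` already returns `None` for an unset variable. The conditional does coerce an empty string to `None` as well (either way `ftplib` falls back to anonymous login for a falsy user); if that coercion is intentional it could be written more directly:

```python
import os

# `or None` keeps the empty-string-to-None coercion explicit; plain
# os.getenv("FTP_USERNAME") already yields None when the variable is unset.
ftp_username = os.getenv("FTP_USERNAME") or None
ftp_password = os.getenv("FTP_PASSWORD") or None
```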
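The second hunk applies the same discipline to the Elasticsearch path: previously `outfile` was opened before the query ran and never explicitly closed, so a failed query left an empty file behind and a successful one relied on interpreter shutdown to flush the buffer. Opening inside a `with` block after the query fixes both. In miniature (illustrative names; `records` stands in for the query result):

```python
import json

def dump_records(records, outfile: str) -> None:
    # Opened only once there is something to write, and guaranteed to be
    # flushed and closed on every exit path, including exceptions.
    with open(outfile, "w") as output:
        json.dump(records, output)
```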