
Commit

Merge pull request #40 from monarch-initiative/file-open-fixes
File open context manager fixes
ptgolden authored Sep 12, 2024
2 parents 2c0f3d2 + d957ece commit 50d8506
Showing 1 changed file with 172 additions and 169 deletions.
341 changes: 172 additions & 169 deletions kghub_downloader/download_utils.py
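Most of the changed line count appears to come from de-indenting the body of download_from_yaml so that it no longer runs inside the `with open(yaml_file) as f:` block; the remaining changes split a combined urlopen/open context manager and wrap the Elasticsearch query output in a `with` block. A condensed, self-contained sketch of the before/after pattern (the URL, paths, and `records` value below are placeholders, not taken from the repository):

    import json
    from urllib.request import Request, urlopen

    example_url = "https://example.com/data.txt"  # placeholder URL
    example_path = "downloaded.bin"               # placeholder output path
    records = {"hits": []}                        # placeholder query result

    # Before: one context manager covered both the HTTP response and the output
    # file, and the query output file was opened without a context manager:
    #
    #     with urlopen(req) as response, open(outfile, "wb") as out_file:
    #         out_file.write(response.read())
    #     ...
    #     output = open(outfile, "w")
    #     json.dump(records, output)

    # After: each handle gets its own `with` block and is closed as soon as its
    # work is done.
    req = Request(example_url, headers={"User-Agent": "Mozilla/5.0"})
    with urlopen(req) as response:
        data = response.read()

    with open(example_path, "wb") as out_file:
        out_file.write(data)

    with open("records.json", "w") as output:
        json.dump(records, output)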
@@ -55,184 +55,187 @@ def download_from_yaml(
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
    with open(yaml_file) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)

    # Limit to only tagged downloads, if tags are passed in
    if tags:
        data = [
            item
            for item in data
            if "tag" in item and item["tag"] and item["tag"] in tags
        ]

    for item in tqdm(data, desc="Downloading files"):
        if "url" not in item:
            logging.error("Couldn't find url for source in {}".format(item))
            continue
        if snippet_only and (item["local_name"])[-3:] in [
            "zip",
            ".gz",
        ]:  # Can't truncate compressed files
            logging.error(
                "Asked to download snippets; can't snippet {}".format(item)
            )
            continue

        local_name = (
            item["local_name"]
            if "local_name" in item and item["local_name"]
            else item["url"].split("/")[-1]
        )
        outfile = os.path.join(output_dir, local_name)

        logging.info("Retrieving %s from %s" % (outfile, item["url"]))

        if "local_name" in item:
            local_file_dir = os.path.join(
                output_dir, os.path.dirname(item["local_name"])
            )
            if not os.path.exists(local_file_dir):
                logging.info(f"Creating local directory {local_file_dir}")
                pathlib.Path(local_file_dir).mkdir(parents=True, exist_ok=True)

        if os.path.exists(outfile):
            if ignore_cache:
                logging.info("Deleting cached version of {}".format(outfile))
                os.remove(outfile)
            else:
                logging.info("Using cached version of {}".format(outfile))
                continue

        # Download file
        if "api" in item:
            download_from_api(item, outfile)
        if "url" in item:
            url = parse_url(item["url"])
            if url.startswith("gs://"):
                Blob.from_string(url, client=storage.Client()).download_to_filename(
                    outfile
                )
            elif url.startswith("s3://"):
                s3 = boto3.client("s3")
                bucket_name = url.split("/")[2]
                remote_file = "/".join(url.split("/")[3:])
                s3.download_file(bucket_name, remote_file, outfile)
            elif url.startswith("ftp"):
                glob = None
                if "glob" in item:
                    glob = item["glob"]
                ftp_username = (
                    os.getenv("FTP_USERNAME") if os.getenv("FTP_USERNAME") else None
                )
                ftp_password = (
                    os.getenv("FTP_PASSWORD") if os.getenv("FTP_PASSWORD") else None
                )
                host = url.split("/")[0]
                path = "/".join(url.split("/")[1:])
                ftp = ftplib.FTP(host)
                ftp.login(ftp_username, ftp_password)
                download_via_ftp(ftp, path, outfile, glob)
            elif any(
                url.startswith(str(i))
                for i in list(GDOWN_MAP.keys()) + list(GDOWN_MAP.values())
            ):
                # Check if url starts with a key or a value
                for key, value in GDOWN_MAP.items():
                    if url.startswith(str(value)):
                        # If value, then download the file directly
                        gdown.download(url, output=outfile)
                        break
                    elif url.startswith(str(key)):
                        # If key, replace key by value and then download
                        new_url = url.replace(str(key) + ":", str(value))
                        gdown.download(new_url, output=outfile)
                        break
                else:
                    # If the loop completes without breaking (i.e., no match found), throw an error
                    raise ValueError("Invalid URL")
            elif url.startswith("git://"):
                url_split = url.split("/")
                repo_owner = url_split[-3]
                repo_name = url_split[-2]
                asset_name = url_split[-1]
                asset_url = None
                api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
                # Get the list of releases
                response = requests.get(api_url)
                response.raise_for_status()
                releases = response.json()

                if not releases:
                    print("No releases found for this repository.")
                    sys.exit(1)

                # Check if a specific tag is provided
                if "tag" in item:
                    # Find the release with the specified tag
                    tagged_release = next(
                        (
                            release
                            for release in releases
                            if release["tag_name"] == item["tag"]
                        ),
                        None,
                    )
                    if tagged_release:
                        for asset in tagged_release.get("assets", []):
                            if asset["name"] == asset_name:
                                asset_url = asset["browser_download_url"]
                                break

                # If no asset found in the specified tag or no tag provided, check other releases
                if not asset_url:
                    for release in releases:
                        for asset in release.get("assets", []):
                            if asset["name"] == asset_name:
                                asset_url = asset["browser_download_url"]
                                break
                        if asset_url:
                            break

                if not asset_url:
                    print(f"Asset '{asset_name}' not found in any release.")
                    sys.exit(1)

                # Download the asset
                response = requests.get(asset_url, stream=True)
                response.raise_for_status()
                with open(outfile, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
                print(f"Downloaded {asset_name}")

            else:
                req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
                try:
                    with urlopen(req) as response:  # type: ignore
                        if snippet_only:
                            data = response.read(
                                5120
                            )  # first 5 kB of a `bytes` object
                        else:
                            data = response.read()  # a `bytes` object

                    with open(outfile, "wb") as out_file:
                        out_file.write(data)
                    if snippet_only:  # Need to clean up the outfile
                        in_file = open(outfile, "r+")
                        in_lines = in_file.read()
                        in_file.close()
                        splitlines = in_lines.split("\n")
                        outstring = "\n".join(splitlines[:-1])
                        cleanfile = open(outfile, "w+")
                        for i in range(len(outstring)):
                            cleanfile.write(outstring[i])
                        cleanfile.close()
                except URLError:
                    logging.error(f"Failed to download: {url}")
                    raise

        # If mirror, upload to remote storage
        if mirror:
            mirror_to_bucket(
                local_file=outfile, bucket_url=mirror, remote_file=local_name
            )

    return None

@@ -309,11 +312,11 @@ def download_from_api(yaml_item, outfile) -> None:
        query_data = compress_json.local_load(
            os.path.join(os.getcwd(), yaml_item["query_file"])
        )
        records = elastic_search_query(
            es_conn, index=yaml_item["index"], query=query_data
        )
        with open(outfile, "w") as output:
            json.dump(records, output)
        return None
    else:
        raise RuntimeError(f"API {yaml_item['api']} not supported")
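
For reference, a minimal sketch of how download_from_yaml (the function changed in the first hunk) is typically driven: a YAML file listing download items, each with a url plus optional local_name and tag keys. The file names, URL, and tag value below are illustrative placeholders, not taken from the repository:

    # download_config.yaml (illustrative):
    #
    # - url: "https://example.com/ontology.json"
    #   local_name: ontology.json
    #   tag: ontology
    # - url: "https://example.com/edges.tsv"
    #   local_name: edges.tsv

    from kghub_downloader.download_utils import download_from_yaml

    download_from_yaml(
        yaml_file="download_config.yaml",  # path to the config sketched above
        output_dir="data/raw",             # downloaded files land here
        tags=["ontology"],                 # optional: keep only items whose tag matches
    )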
