Skip to content

Commit

Permalink
fixup! Fix duplicate downloads (#450)
Browse files: browse the repository at this point in the history
  • Loading branch information
jessebrennan committed Nov 23, 2019
1 parent 4e3ef4e commit 2d0c1c1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
11 changes: 5 additions & 6 deletions hca/dss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from io import open

import requests
from atomicwrites import atomic_write
from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout

from hca.dss.util import iter_paths, object_name_builder, hardlink, atomic_write
from hca.dss.util import iter_paths, object_name_builder, hardlink, atomic_overwrite
from glob import escape as glob_escape
from hca.util import tsv
from ..util import SwaggerClient, DEFAULT_THREAD_COUNT
Expand Down Expand Up @@ -527,7 +528,7 @@ def _download_bundle_manifest(self, manifest_bytes, bundle_dir, dss_file):
logger.info("Skipping download of '%s' because it already exists at '%s'.", dss_file.name, dest_path)
else:
self._make_dirs_if_necessary(dest_path)
with atomic_write(dest_path, mode="wb") as fh:
with atomic_overwrite(dest_path, mode="wb") as fh:
fh.write(manifest_bytes)
file_path = os.path.join(bundle_dir, dss_file.name)
self._make_dirs_if_necessary(file_path)
Expand Down Expand Up @@ -591,7 +592,7 @@ def _download_file(self, dss_file, dest_path):
ranged get doesn't yield the correct header, then we start over.
"""
self._make_dirs_if_necessary(dest_path)
with atomic_write(dest_path, mode="wb") as fh:
with atomic_overwrite(dest_path, mode="wb") as fh:
if dss_file.size == 0:
return

Expand Down Expand Up @@ -759,9 +760,7 @@ def _write_output_manifest(self):
fieldnames, source_manifest = self._parse_manifest(self.manifest)
if 'file_path' not in fieldnames:
fieldnames.append('file_path')
if os.path.isfile(output):
logger.warning('Overwriting manifest %s', output)
with open(output, mode='w', newline='') as f:
with atomic_write(output, overwrite=True, newline='') as f:
writer = tsv.DictWriter(f, fieldnames)
writer.writeheader()
for row in source_manifest:
Expand Down
3 changes: 1 addition & 2 deletions hca/dss/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import os
from builtins import FileExistsError

Expand Down Expand Up @@ -57,7 +56,7 @@ def hardlink(source, link_name):
raise


class atomic_write:
class atomic_overwrite:
"""Atomically write, but don't complain if file already exists"""

def __init__(self, *args, **kwargs):
Expand Down
20 changes: 10 additions & 10 deletions test/unit/test_dss_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _fake_download_file(*args, **kwargs):
_touch_file(args[1])


def _fake_paginate_factory(fake_hash=False):
def _make_fake_paginate(fake_hash=False):
def _fake_get_bundle_paginate(*args, **kwargs):
bundle_dict = {
'version': '1_version',
Expand Down Expand Up @@ -88,7 +88,7 @@ def _fake_do_download_file_with_barrier(*args, **kwargs):
fh.write(b'Here we write some stuff so that the fake download takes some time. '
b'This helps ensure that multiple threads are writing at once and thus '
b'allows us to test for race conditions.')
time.sleep(random.random()) # sleep for some small random amount of time
time.sleep(random.random())
return 'FAKEhash'


Expand Down Expand Up @@ -318,7 +318,7 @@ def _assert_all_files_downloaded(self, more_files=None, prefix=''):
@patch('hca.dss.DSSClient.get_bundle')
@patch('hca.dss.DownloadContext._download_file', side_effect=_fake_download_file)
def test_manifest_download_bundle(self, _, mock_get_bundle):
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.dss.download_manifest(self.manifest_file, 'aws', layout='bundle')
self._assert_all_files_downloaded()
self.dss.download_manifest(self.manifest_file, 'aws', layout='bundle')
Expand All @@ -329,7 +329,7 @@ def test_manifest_download_bundle(self, _, mock_get_bundle):
def _test_download_dir(self, download_dir):
with patch('hca.dss.DownloadContext._download_file', side_effect=_fake_download_file), \
patch('hca.dss.DSSClient.get_bundle') as mock_get_bundle:
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.dss.download_manifest(self.manifest_file, 'aws', layout='bundle', download_dir=download_dir)
self._assert_all_files_downloaded(prefix=download_dir)
self.dss.download_manifest(self.manifest_file, 'aws', layout='bundle', download_dir=download_dir)
Expand Down Expand Up @@ -365,7 +365,7 @@ def test_manifest_download_bad_file(self, _, mock_get_bundle):
Ensure error is raised if a user created file has the same name as the one
we're trying to download.
"""
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
manifest_directory = self.manifest[1][0] + '.' + self.manifest[1][1]
_touch_file(os.path.join(manifest_directory, self.manifest[1][3]))
self.assertRaises(RuntimeError, self.dss.download_manifest, self.manifest_file, 'aws', layout='bundle')
Expand All @@ -381,7 +381,7 @@ def test_manifest_download_bundle_parallel(self, mock_get_bundle):
"""
random.seed('same seed for consistency')
self._write_uniform_manifest()
mock_get_bundle.paginate = _fake_paginate_factory(fake_hash=True)
mock_get_bundle.paginate = _make_fake_paginate(fake_hash=True)
with patch('hca.dss.DownloadContext._do_download_file', side_effect=_fake_do_download_file_with_barrier):
# 3 threads for three files with barrier size 3
with patch('hca.dss.TaskRunner', return_value=TaskRunner(threads=3)):
Expand Down Expand Up @@ -410,7 +410,7 @@ class TestDownload(AbstractTestDSSClient):
@patch('hca.dss.DSSClient.get_bundle')
@patch('hca.dss.DownloadContext._download_file', side_effect=_fake_download_file)
def test_download(self, _, mock_get_bundle):
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.dss.download('any_bundle_uuid', 'aws')
more_files = {os.path.join('.', 'any_bundle_uuid', file_name)
for file_name in ['a_file_name', 'b_file_name', 'c_file_name', 'metadata_file.pdf']}
Expand Down Expand Up @@ -440,7 +440,7 @@ def _test_download_filters(self, no_metadata, no_data):
all_files = metadata_files.union(data_files)
with patch('hca.dss.DSSClient.get_bundle') as mock_get_bundle, \
patch('hca.dss.DownloadContext._download_file', side_effect=_fake_download_file):
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.dss.download('any_bundle_uuid', 'aws', no_metadata=no_metadata, no_data=no_data)
expected_files = all_files
if no_data:
Expand All @@ -464,15 +464,15 @@ def test_download_filters_conflict(self):
@patch('logging.Logger.warning')
@patch('hca.dss.DownloadContext._download_file', side_effect=[None, ValueError(), KeyError()])
def test_manifest_download_failed(self, _, warning_log, mock_get_bundle):
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.assertRaises(RuntimeError, self.dss.download, 'any_bundle_uuid', 'aws')
self.assertEqual(warning_log.call_count, 4)
self._assert_manifest_not_updated()

def _test_download_dir(self, download_dir):
with patch('hca.dss.DownloadContext._download_file', side_effect=_fake_download_file), \
patch('hca.dss.DSSClient.get_bundle') as mock_get_bundle:
mock_get_bundle.paginate = _fake_paginate_factory()
mock_get_bundle.paginate = _make_fake_paginate()
self.dss.download('any_bundle_uuid', 'aws')
more_files = {os.path.join(download_dir, 'any_bundle_uuid', file_name)
for file_name in ['a_file_name', 'b_file_name', 'c_file_name', 'metadata_file.pdf']}
Expand Down

0 comments on commit 2d0c1c1

Please sign in to comment.