Skip to content

Commit

Permalink
remove irods session context manager and multiple sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
sellth committed Nov 28, 2023
1 parent e1625a8 commit 1a58949
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 42 deletions.
36 changes: 11 additions & 25 deletions cubi_tk/irods_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from contextlib import contextmanager
import getpass
import os.path
from pathlib import Path
Expand All @@ -22,8 +21,6 @@
formatter = logzero.LogFormatter(fmt="%(message)s")
output_logger = logzero.setup_logger(formatter=formatter)

NUM_PARALLEL_SESSIONS = 4


@attrs.frozen(auto_attribs=True)
class TransferJob:
Expand Down Expand Up @@ -139,17 +136,6 @@ def _save_irods_token(self, token: str):
else:
logger.warning("No token found to be saved.")

@contextmanager
def _get_irods_sessions(self, count=NUM_PARALLEL_SESSIONS):
if count < 1:
count = 1
irods_sessions = [self._init_irods() for _ in range(count)]
try:
yield irods_sessions
finally:
for irods in irods_sessions:
irods.cleanup()

@property
def session(self):
return self._init_irods()
Expand Down Expand Up @@ -183,8 +169,8 @@ def destinations(self):

def _create_collections(self, job: TransferJob):
collection = str(Path(job.path_remote).parent)
with self._get_irods_sessions(1) as session:
session[0].collections.create(collection)
with self.session as session:
session.collections.create(collection)

def put(self, recursive: bool = False, sync: bool = False):
# Double tqdm for currently transferred file info
Expand All @@ -201,13 +187,13 @@ def put(self, recursive: bool = False, sync: bool = False):
f"File [{n + 1}/{len(self.__jobs)}]: {Path(job.path_local).name}"
)
try:
with self._get_irods_sessions(1) as session:
with self.session as session:
if recursive:
self._create_collections(job)
if sync and session[0].data_objects.exists(job.path_remote):
if sync and session.data_objects.exists(job.path_remote):
t.update(job.bytes)
continue
session[0].data_objects.put(job.path_local, job.path_remote)
session.data_objects.put(job.path_local, job.path_remote)
t.update(job.bytes)
except Exception as e: # pragma: no cover
logger.error(f"Problem during transfer of {job.path_local}")
Expand All @@ -222,8 +208,8 @@ def chksum(self):
if not job.path_local.endswith(".md5"):
output_logger.info(Path(job.path_remote).relative_to(common_prefix))
try:
with self._get_irods_sessions(1) as session:
data_object = session[0].data_objects.get(job.path_remote)
with self.session as session:
data_object = session.data_objects.get(job.path_remote)
if not data_object.checksum:
data_object.chksum()
except Exception as e: # pragma: no cover
Expand All @@ -232,9 +218,9 @@ def chksum(self):

def get(self):
"""Download files from SODAR."""
with self._get_irods_sessions(1) as session:
with self.session as session:
self.__jobs = [
attrs.evolve(job, bytes=session[0].data_objects.get(job.path_remote).size)
attrs.evolve(job, bytes=session.data_objects.get(job.path_remote).size)
for job in self.__jobs
]
self.__total_bytes = sum([job.bytes for job in self.__jobs])
Expand All @@ -252,8 +238,8 @@ def get(self):
f"File [{n + 1}/{len(self.__jobs)}]: {Path(job.path_local).name}"
)
try:
with self._get_irods_sessions(1) as session:
session[0].data_objects.get(job.path_remote, job.path_local)
with self.session as session:
session.data_objects.get(job.path_remote, job.path_local)
t.update(job.bytes)
except FileNotFoundError: # pragma: no cover
raise
Expand Down
28 changes: 11 additions & 17 deletions tests/test_irods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,6 @@ def test_save_irods_token(mocksession, mockencode, fs):
mockencode.assert_called_with("secure")


@patch("cubi_tk.irods_common.iRODSSession")
def test_get_irods_sessions(mocksession):
with iRODSCommon()._get_irods_sessions(count=3) as sessions:
assert len(sessions) == 3
with iRODSCommon()._get_irods_sessions(count=-1) as sessions:
assert len(sessions) == 1


# Test iRODSTransfer #########
@pytest.fixture
def jobs():
Expand All @@ -95,15 +87,17 @@ def test_irods_transfer_init(jobs):
assert itransfer.destinations == [job.path_remote for job in jobs]


@patch("cubi_tk.irods_common.iRODSSession")
@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
@patch("cubi_tk.irods_common.iRODSTransfer._create_collections")
def test_irods_transfer_put(mockrecursive, mocksession, jobs):
mockput = MagicMock()
mockexists = MagicMock()
mockexists = MagicMock(return_value=True)
mockobj = MagicMock()
mockobj.put = mockput
mockobj.exists = mockexists
mocksession.return_value.data_objects = mockobj

# fit for context management
mocksession.return_value.__enter__.return_value.data_objects = mockobj
itransfer = iRODSTransfer(jobs)

# put
Expand All @@ -123,25 +117,25 @@ def test_irods_transfer_put(mockrecursive, mocksession, jobs):
mockexists.assert_called()


@patch("cubi_tk.irods_common.iRODSSession")
@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
def test_create_collections(mocksession, jobs):
mockcreate = MagicMock()
mockcoll = MagicMock()
mockcoll.create = mockcreate
mocksession.return_value.collections = mockcoll
mocksession.return_value.__enter__.return_value.collections = mockcoll
itransfer = iRODSTransfer(jobs)

itransfer._create_collections(itransfer.jobs[1])
coll_path = str(Path(itransfer.jobs[1].path_remote).parent)
mockcreate.assert_called_with(coll_path)


@patch("cubi_tk.irods_common.iRODSSession")
@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
def test_irods_transfer_chksum(mocksession, jobs):
mockget = MagicMock()
mockobj = MagicMock()
mockobj.get = mockget
mocksession.return_value.data_objects = mockobj
mocksession.return_value.__enter__.return_value.data_objects = mockobj

mock_data_object = MagicMock()
mock_data_object.checksum = None
Expand All @@ -156,12 +150,12 @@ def test_irods_transfer_chksum(mocksession, jobs):
mockget.assert_any_call(path)


@patch("cubi_tk.irods_common.iRODSSession")
@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
def test_irods_transfer_get(mocksession, jobs):
mockget = MagicMock()
mockobj = MagicMock()
mockobj.get = mockget
mocksession.return_value.data_objects = mockobj
mocksession.return_value.__enter__.return_value.data_objects = mockobj
itransfer = iRODSTransfer(jobs)

mockget.return_value.size = 111
Expand Down

0 comments on commit 1a58949

Please sign in to comment.