diff --git a/aiida_firecrest/transport.py b/aiida_firecrest/transport.py index 32558fc..ae1a35d 100644 --- a/aiida_firecrest/transport.py +++ b/aiida_firecrest/transport.py @@ -116,6 +116,17 @@ class FirecrestTransport(Transport): "callback": validate_positive_number, }, ), + ( + "file_transfer_poll_interval", + { + "type": float, + "default": 0.1, # TODO what default to choose? + "non_interactive_default": True, + "prompt": "File transfer poll interval (s)", + "help": "Poll interval when waiting for large file transfers.", + "callback": validate_positive_number, + }, + ), ] def __init__( @@ -126,13 +137,17 @@ def __init__( client_id: str, client_secret: str | Path, client_machine: str, - small_file_size_mb: float, + small_file_size_mb: float = 5.0, + file_transfer_poll_interval: float = 0.1, # note, machine is provided by default, # for the hostname, but we don't use that # TODO ideally hostname would not be necessary on a computer **kwargs: Any, ): """Construct a FirecREST transport.""" + # there is no overhead for "opening" a connection to a REST-API, + # but still allow the user to set a safe interval if they really want to + kwargs.setdefault("safe_interval", 0) super().__init__(**kwargs) # type: ignore assert isinstance(url, str), "url must be a string" @@ -146,12 +161,16 @@ def __init__( assert isinstance( small_file_size_mb, float ), "small_file_size_mb must be a float" + assert isinstance( + file_transfer_poll_interval, float + ), "file_transfer_poll_interval must be a float" self._machine = client_machine self._url = url self._token_uri = token_uri self._client_id = client_id self._small_file_size_bytes = int(small_file_size_mb * 1024 * 1024) + self._file_transfer_poll_interval = file_transfer_poll_interval secret = ( client_secret.read_text() @@ -349,9 +368,8 @@ def getfile( # this waits for the file to be moved to the staging area # TODO handle the transfer stalling (timeout?) and optimise the polling interval - # (and allow configurability?) while down_obj.in_progress: - time.sleep(0.1) + time.sleep(self._file_transfer_poll_interval) # this downloads the file from the "staging area" url = down_obj.object_storage_data @@ -443,9 +461,8 @@ def putfile( up_obj.finish_upload() # this waits for the file in the staging area to be moved to the final location # TODO handle the transfer stalling (timeout?) and optimise the polling interval - # (and allow configurability?) while up_obj.in_progress: - time.sleep(0.1) + time.sleep(self._file_transfer_poll_interval) # TODO use cwd.checksum to confirm upload is not corrupted?