diff --git a/.readthedocs.yml b/.readthedocs.yml
index 8f1e3118d0..6b64320875 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -20,7 +20,7 @@ build:
       - asdf install uv 0.2.9
       - asdf global uv 0.2.9
     post_install:
-      - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install .[docs,tests,rest,atomic_tools]
+      - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install .[docs,tests,rest,atomic_tools] --preview

 # Let the build fail if there are any warnings
 sphinx:
diff --git a/docs/source/topics/transport.rst b/docs/source/topics/transport.rst
index 45706b94a7..d95cbeaac2 100644
--- a/docs/source/topics/transport.rst
+++ b/docs/source/topics/transport.rst
@@ -24,15 +24,15 @@ The generic transport class contains a set of minimal methods that an implementa
 If not, a ``NotImplementedError`` will be raised, interrupting the managing of the calculation or whatever is using the transport plugin.
 As for the general functioning of the plugin, the :py:meth:`~aiida.transports.transport.Transport.__init__` method is used only to initialize the class instance, without actually opening the transport channel.
-The connection must be opened only by the :py:meth:`~aiida.transports.transport.Transport.__enter__` method, (and closed by :py:meth:`~aiida.transports.transport.Transport.__exit__`).
-The :py:meth:`~aiida.transports.transport.Transport.__enter__` method lets you use the transport class using the ``with`` statement (see `python docs `_), in a way similar to the following:
+The connection must be opened only by the :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` method (and closed by :py:meth:`~aiida.transports.transport._BaseTransport.__exit__`).
+The :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` method lets you use the transport class using the ``with`` statement (see `python docs `_), in a way similar to the following:

 .. code-block:: python

     with TransportPlugin() as transport:
         transport.some_method()

-To ensure this, for example, the local plugin uses a hidden boolean variable ``_is_open`` that is set when the :py:meth:`~aiida.transports.transport.Transport.__enter__` and :py:meth:`~aiida.transports.transport.Transport.__exit__` methods are called.
+To ensure this, for example, the local plugin uses a hidden boolean variable ``_is_open`` that is set when the :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` and :py:meth:`~aiida.transports.transport._BaseTransport.__exit__` methods are called.
 The ``ssh`` logic is instead given by the property sftp.
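A minimal plugin skeleton following this pattern could look as below. This is an illustrative sketch only: the ``_connect`` and ``_disconnect`` helpers are hypothetical placeholders for whatever the plugin actually does to establish and tear down its channel, and a real plugin must also implement the remaining abstract methods of the base class.

.. code-block:: python

    from aiida.transports import Transport

    class ExampleTransport(Transport):
        """Sketch of a transport plugin that tracks its open/closed state."""

        def __enter__(self):
            self._connect()        # hypothetical helper that opens the channel
            self._is_open = True   # mark the transport as usable
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._disconnect()     # hypothetical helper that closes the channel
            self._is_open = False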
The other functions that require some care are the copying functions, called using the following terminology: diff --git a/environment.yml b/environment.yml index ad80dd3416..9ce52da6f5 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 +- asyncssh~=2.19.0 - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 @@ -22,7 +23,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy~=0.22.3 +- plumpy - pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/pyproject.toml b/pyproject.toml index 9461b7a46a..073c5414f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', + "asyncssh~=2.19.0", 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', @@ -34,7 +35,7 @@ dependencies = [ 'importlib-metadata~=6.0', 'numpy~=1.21', 'paramiko~=3.0', - 'plumpy~=0.22.3', + "plumpy", 'pgsu~=0.3.0', 'psutil~=5.6', 'psycopg[binary]~=3.0', @@ -175,6 +176,7 @@ requires-python = '>=3.9' [project.entry-points.'aiida.transports'] 'core.local' = 'aiida.transports.plugins.local:LocalTransport' 'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport' +'core.ssh_async' = 'aiida.transports.plugins.ssh_async:AsyncSshTransport' 'core.ssh_auto' = 'aiida.transports.plugins.ssh_auto:SshAutoTransport' [project.entry-points.'aiida.workflows'] @@ -308,6 +310,7 @@ module = 'tests.*' ignore_missing_imports = true module = [ 'ase.*', + 'asyncssh.*', 'bpython.*', 'bs4.*', 'CifFile.*', @@ -509,3 +512,7 @@ passenv = AIIDA_TEST_WORKERS commands = molecule {posargs:test} """ + +[tool.uv.sources] +asyncssh = {git = "https://github.com/ronf/asyncssh", branch = "develop"} +plumpy = {git = "https://github.com/aiidateam/plumpy", rev = "4611154c76ac0991bcf7371b21488f4390648c28"} diff --git a/src/aiida/calculations/monitors/base.py b/src/aiida/calculations/monitors/base.py index 459f4eba9d..87d5054915 100644 --- a/src/aiida/calculations/monitors/base.py +++ b/src/aiida/calculations/monitors/base.py @@ -4,12 +4,13 @@ import tempfile from pathlib import Path +from typing import Union from aiida.orm import CalcJobNode -from aiida.transports import Transport +from aiida.transports import AsyncTransport, Transport -def always_kill(node: CalcJobNode, transport: Transport) -> str | None: +def always_kill(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | None: """Retrieve and inspect files in working directory of job to determine whether the job should be killed. 
This particular implementation is just for demonstration purposes and will kill the job as long as there is a diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index 73c30cab61..ece32173f9 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -35,7 +35,7 @@ from aiida.schedulers.datastructures import JobState if TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' @@ -62,9 +62,9 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: return data_node -def upload_calculation( +async def upload_calculation( node: CalcJobNode, - transport: Transport, + transport: Union['Transport', 'AsyncTransport'], calc_info: CalcInfo, folder: Folder, inputs: Optional[MappingType[str, Any]] = None, @@ -105,7 +105,7 @@ def upload_calculation( if dry_run: workdir = Path(folder.abspath) else: - remote_user = transport.whoami() + remote_user = await transport.whoami_async() remote_working_directory = computer.get_workdir().format(username=remote_user) if not remote_working_directory.strip(): raise exceptions.ConfigurationError( @@ -114,13 +114,13 @@ def upload_calculation( ) # If it already exists, no exception is raised - if not transport.path_exists(remote_working_directory): + if not await transport.path_exists_async(remote_working_directory): logger.debug( f'[submission of calculation {node.pk}] Path ' f'{remote_working_directory} does not exist, trying to create it' ) try: - transport.makedirs(remote_working_directory) + await transport.makedirs_async(remote_working_directory) except EnvironmentError as exc: raise exceptions.ConfigurationError( f'[submission of calculation {node.pk}] ' @@ -133,14 +133,14 @@ def upload_calculation( # and I do not have to know the logic, but I just need to # read the absolute path from the calculation properties. workdir = Path(remote_working_directory).joinpath(calc_info.uuid[:2], calc_info.uuid[2:4]) - transport.makedirs(str(workdir), ignore_existing=True) + await transport.makedirs_async(workdir, ignore_existing=True) try: # The final directory may already exist, most likely because this function was already executed once, but # failed and as a result was rescheduled by the engine. In this case it would be fine to delete the folder # and create it from scratch, except that we cannot be sure that this the actual case. 
Therefore, to err on # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) except OSError: # Move the existing directory to lost+found, log a warning and create a clean directory anyway path_existing = os.path.join(str(workdir), calc_info.uuid[4:]) @@ -151,12 +151,12 @@ def upload_calculation( ) # Make sure the lost+found directory exists, then copy the existing folder there and delete the original - transport.mkdir(path_lost_found, ignore_existing=True) - transport.copytree(path_existing, path_target) - transport.rmtree(path_existing) + await transport.mkdir_async(path_lost_found, ignore_existing=True) + await transport.copytree_async(path_existing, path_target) + await transport.rmtree_async(path_existing) # Now we can create a clean folder for this calculation - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) finally: workdir = workdir.joinpath(calc_info.uuid[4:]) @@ -171,11 +171,11 @@ def upload_calculation( # Note: this will possibly overwrite files for root, dirnames, filenames in code.base.repository.walk(): # mkdir of root - transport.makedirs(str(workdir.joinpath(root)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root), ignore_existing=True) # remotely mkdir first for dirname in dirnames: - transport.makedirs(str(workdir.joinpath(root, dirname)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root, dirname), ignore_existing=True) # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` @@ -185,11 +185,11 @@ def upload_calculation( content = code.base.repository.get_object_content(Path(root) / filename, mode='rb') handle.write(content) handle.flush() - transport.put(handle.name, str(workdir.joinpath(root, filename))) + await transport.put_async(handle.name, workdir.joinpath(root, filename)) if code.filepath_executable.is_absolute(): - transport.chmod(str(code.filepath_executable), 0o755) # rwxr-xr-x + await transport.chmod_async(code.filepath_executable, 0o755) # rwxr-xr-x else: - transport.chmod(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x + await transport.chmod_async(workdir.joinpath(code.filepath_executable), 0o755) # rwxr-xr-x # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() @@ -206,15 +206,15 @@ def upload_calculation( for file_copy_operation in file_copy_operation_order: if file_copy_operation is FileCopyOperation.LOCAL: - _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir) + await _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir) elif file_copy_operation is FileCopyOperation.REMOTE: if not dry_run: - _copy_remote_files( + await _copy_remote_files( logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir=workdir ) elif file_copy_operation is FileCopyOperation.SANDBOX: if not dry_run: - _copy_sandbox_files(logger, node, transport, folder, workdir=workdir) + await _copy_sandbox_files(logger, node, transport, folder, workdir=workdir) else: raise RuntimeError(f'file copy operation {file_copy_operation} is not yet implemented.') @@ -279,7 +279,7 @@ def 
upload_calculation( return None -def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path): +async def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path): """Perform the copy instructions of the ``remote_copy_list`` and ``remote_symlink_list``.""" for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: if remote_computer_uuid == computer.uuid: @@ -288,7 +288,7 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remo f'remotely, directly on the machine {computer.label}' ) try: - transport.copy(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.copy_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except FileNotFoundError: logger.warning( f'[submission of calculation {node.pk}] Unable to copy remote ' @@ -314,8 +314,8 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remo ) remote_dirname = Path(dest_rel_path).parent try: - transport.makedirs(str(workdir.joinpath(remote_dirname)), ignore_existing=True) - transport.symlink(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.makedirs_async(workdir.joinpath(remote_dirname), ignore_existing=True) + await transport.symlink_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except OSError: logger.warning( f'[submission of calculation {node.pk}] Unable to create remote symlink ' @@ -328,7 +328,7 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remo ) -def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path): +async def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path): """Perform the copy instructions of the ``local_copy_list``.""" for uuid, filename, target in local_copy_list: logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}') @@ -356,14 +356,14 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: # The logic below takes care of an edge case where the source is a file but the target is a directory. In # this case, the v2.5.1 implementation would raise an `IsADirectoryError` exception, because it would try # to open the directory in the sandbox folder as a file when writing the contents. - if file_type_source == FileType.FILE and target and transport.isdir(str(workdir.joinpath(target))): + if file_type_source == FileType.FILE and target and await transport.isdir_async(workdir.joinpath(target)): raise IsADirectoryError # In case the source filename is specified and it is a directory that already exists in the remote, we # want to avoid nested directories in the target path to replicate the behavior of v2.5.1. This is done by # setting the target filename to '.', which means the contents of the node will be copied in the top level # of the temporary directory, whose contents are then copied into the target directory. - if filename and transport.isdir(str(workdir.joinpath(filename))): + if filename and await transport.isdir_async(workdir.joinpath(filename)): filename_target = '.' 
filepath_target = (dirpath / filename_target).resolve().absolute() @@ -372,9 +372,9 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: if file_type_source == FileType.DIRECTORY: # If the source object is a directory, we copy its entire contents data_node.base.repository.copy_tree(filepath_target, filename_source) - transport.put( + await transport.put_async( f'{dirpath}/*', - str(workdir.joinpath(target)) if target else str(workdir.joinpath('.')), + workdir.joinpath(target) if target else workdir.joinpath('.'), overwrite=True, ) else: @@ -382,18 +382,18 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: with filepath_target.open('wb') as handle: with data_node.base.repository.open(filename_source, 'rb') as source: shutil.copyfileobj(source, handle) - transport.makedirs(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) - transport.put(str(filepath_target), str(workdir.joinpath(target))) + await transport.makedirs_async(workdir.joinpath(Path(target).parent), ignore_existing=True) + await transport.put_async(filepath_target, workdir.joinpath(target)) -def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): +async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): """Copy the contents of the sandbox folder to the working directory.""" for filename in folder.get_content_list(): logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') - transport.put(folder.get_abs_path(filename), str(workdir.joinpath(filename))) + await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename)) -def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode: +def submit_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | ExitCode: """Submit a previously uploaded `CalcJob` to the scheduler. :param calculation: the instance of CalcJobNode to submit. @@ -423,7 +423,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | return result -def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: +async def stash_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None: """Stash files from the working directory of a completed calculation to a permanent remote folder. 
After a calculation has been completed, optionally stash files from the work directory to a storage location on the @@ -461,7 +461,7 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: for source_filename in source_list: if transport.has_magic(source_filename): copy_instructions = [] - for globbed_filename in transport.glob(str(source_basepath / source_filename)): + for globbed_filename in await transport.glob_async(source_basepath / source_filename): target_filepath = target_basepath / Path(globbed_filename).relative_to(source_basepath) copy_instructions.append((globbed_filename, target_filepath)) else: @@ -470,10 +470,10 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: for source_filepath, target_filepath in copy_instructions: # If the source file is in a (nested) directory, create those directories first in the target directory target_dirname = target_filepath.parent - transport.makedirs(str(target_dirname), ignore_existing=True) + await transport.makedirs_async(target_dirname, ignore_existing=True) try: - transport.copy(str(source_filepath), str(target_filepath)) + await transport.copy_async(source_filepath, target_filepath) except (OSError, ValueError) as exception: EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') else: @@ -488,8 +488,8 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') -def retrieve_calculation( - calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str +async def retrieve_calculation( + calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], retrieved_temporary_folder: str ) -> FolderData | None: """Retrieve all the files of a completed job calculation using the given transport. @@ -529,14 +529,14 @@ def retrieve_calculation( retrieve_temporary_list = calculation.get_retrieve_temporary_list() with SandboxFolder(filepath_sandbox) as folder: - retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) + await retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) # Here I retrieved everything; now I store them inside the calculation retrieved_files.base.repository.put_object_from_tree(folder.abspath) # Retrieve the temporary files in the retrieved_temporary_folder if any files were # specified in the 'retrieve_temporary_list' key if retrieve_temporary_list: - retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) + await retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) # Log the files that were retrieved in the temporary folder for filename in os.listdir(retrieved_temporary_folder): @@ -554,7 +554,7 @@ def retrieve_calculation( return retrieved_files -def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: +def kill_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None: """Kill the calculation through the scheduler :param calculation: the instance of CalcJobNode to kill. 
@@ -587,9 +587,9 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: ) -def retrieve_files_from_list( +async def retrieve_files_from_list( calculation: CalcJobNode, - transport: Transport, + transport: Union['Transport', 'AsyncTransport'], folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], list]], ) -> None: @@ -612,7 +612,7 @@ def retrieve_files_from_list( upto what level of the original remotepath nesting the files will be copied. :param transport: the Transport instance. - :param folder: an absolute path to a folder that contains the files to copy. + :param folder: an absolute path to a folder that contains the files to retrieve. :param retrieve_list: the list of files to retrieve. """ workdir = Path(calculation.get_remote_workdir()) @@ -621,7 +621,7 @@ def retrieve_files_from_list( tmp_rname, tmp_lname, depth = item # if there are more than one file I do something differently if transport.has_magic(tmp_rname): - remote_names = transport.glob(str(workdir.joinpath(tmp_rname))) + remote_names = await transport.glob_async(workdir.joinpath(tmp_rname)) local_names = [] for rem in remote_names: # get the relative path so to make local_names relative @@ -644,7 +644,7 @@ def retrieve_files_from_list( abs_item = item if item.startswith('/') else str(workdir.joinpath(item)) if transport.has_magic(abs_item): - remote_names = transport.glob(abs_item) + remote_names = await transport.glob_async(abs_item) local_names = [os.path.split(rem)[1] for rem in remote_names] else: remote_names = [abs_item] @@ -656,6 +656,6 @@ def retrieve_files_from_list( if rem.startswith('/'): to_get = rem else: - to_get = str(workdir.joinpath(rem)) + to_get = workdir.joinpath(rem) - transport.get(to_get, os.path.join(folder, loc), ignore_nonexisting=True) + await transport.get_async(to_get, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/src/aiida/engine/processes/calcjobs/calcjob.py b/src/aiida/engine/processes/calcjobs/calcjob.py index 8ced783a5f..8133ca1c8f 100644 --- a/src/aiida/engine/processes/calcjobs/calcjob.py +++ b/src/aiida/engine/processes/calcjobs/calcjob.py @@ -524,7 +524,7 @@ def on_terminated(self) -> None: super().on_terminated() @override - def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: + async def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: """Run the calculation job. This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the @@ -535,11 +535,11 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wa """ if self.inputs.metadata.dry_run: - self._perform_dry_run() + await self._perform_dry_run() return plumpy.process_states.Stop(None, True) if 'remote_folder' in self.inputs: - exit_code = self._perform_import() + exit_code = await self._perform_import() return exit_code # The following conditional is required for the caching to properly work. Even if the source node has a process @@ -627,7 +627,7 @@ def _setup_inputs(self) -> None: if not self.node.computer: self.node.computer = self.inputs.code.computer - def _perform_dry_run(self): + async def _perform_dry_run(self): """Perform a dry run. 
Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method @@ -643,13 +643,13 @@ def _perform_dry_run(self): with LocalTransport() as transport: with SubmitTestFolder() as folder: calc_info = self.presubmit(folder) - upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) + await upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) self.node.dry_run_info = { # type: ignore[attr-defined] 'folder': folder.abspath, 'script_filename': self.node.get_option('submit_script_filename'), } - def _perform_import(self): + async def _perform_import(self): """Perform the import of an already completed calculation. The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run @@ -669,7 +669,7 @@ def _perform_import(self): with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder: self.presubmit(folder) self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path()) - retrieved = retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) + retrieved = await retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) if retrieved is not None: self.out(self.node.link_label_retrieved, retrieved) self.update_outputs() diff --git a/src/aiida/engine/processes/calcjobs/monitors.py b/src/aiida/engine/processes/calcjobs/monitors.py index 507122ff1e..e13f01a5f3 100644 --- a/src/aiida/engine/processes/calcjobs/monitors.py +++ b/src/aiida/engine/processes/calcjobs/monitors.py @@ -8,6 +8,7 @@ import inspect import typing as t from datetime import datetime, timedelta +from typing import Union from aiida.common.lang import type_check from aiida.common.log import AIIDA_LOGGER @@ -15,7 +16,7 @@ from aiida.plugins import BaseFactory if t.TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport LOGGER = AIIDA_LOGGER.getChild(__name__) @@ -122,7 +123,9 @@ def validate(self): parameters = list(signature.parameters.keys()) if any(required_parameter not in parameters for required_parameter in ('node', 'transport')): - correct_signature = '(node: CalcJobNode, transport: Transport, **kwargs) str | None:' + correct_signature = ( + "(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], **kwargs) str | None:" + ) raise ValueError( f'The monitor `{self.entry_point}` has an invalid function signature, it should be: {correct_signature}' ) @@ -176,7 +179,7 @@ def monitors(self) -> collections.OrderedDict: def process( self, node: CalcJobNode, - transport: Transport, + transport: Union['Transport', 'AsyncTransport'], ) -> CalcJobMonitorResult | None: """Call all monitors in order and return the result as one returns anything other than ``None``. 
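Given the relaxed signature check above, a monitor only needs to accept ``node`` and ``transport`` parameters, where the transport may now be either flavour. A conforming monitor could be sketched as follows, assuming the blocking ``path_exists`` call remains available on both transport types; the kill condition itself is purely illustrative:

    from __future__ import annotations

    from typing import Union

    from aiida.orm import CalcJobNode
    from aiida.transports import AsyncTransport, Transport


    def kill_if_workdir_missing(node: CalcJobNode, transport: Union[Transport, AsyncTransport]) -> str | None:
        """Hypothetical monitor: request a kill if the remote working directory has disappeared."""
        workdir = node.get_remote_workdir()
        if workdir and not transport.path_exists(workdir):
            return 'remote working directory no longer exists'
        # returning None lets the calculation keep running
        return None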
diff --git a/src/aiida/engine/processes/calcjobs/tasks.py b/src/aiida/engine/processes/calcjobs/tasks.py index 1059d277ba..80617e3bfd 100644 --- a/src/aiida/engine/processes/calcjobs/tasks.py +++ b/src/aiida/engine/processes/calcjobs/tasks.py @@ -92,7 +92,7 @@ async def do_upload(): except Exception as exception: raise PreSubmitException('exception occurred in presubmit call') from exception else: - remote_folder = execmanager.upload_calculation(node, transport, calc_info, folder) + remote_folder = await execmanager.upload_calculation(node, transport, calc_info, folder) if remote_folder is not None: process.out('remote_folder', remote_folder) skip_submit = calc_info.skip_submit or False @@ -314,7 +314,7 @@ async def do_retrieve(): if node.get_job_id() is None: logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`') - retrieved = execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) else: try: detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id()) @@ -324,7 +324,7 @@ async def do_retrieve(): else: node.set_detailed_job_info(detailed_job_info) - retrieved = execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) if retrieved is not None: process.out(node.link_label_retrieved, retrieved) @@ -376,7 +376,7 @@ async def do_stash(): transport = await cancellable.with_interrupt(request) logger.info(f'stashing calculation<{node.pk}>') - return execmanager.stash_calculation(node, transport) + return await execmanager.stash_calculation(node, transport) try: await exponential_backoff_retry( diff --git a/src/aiida/engine/processes/functions.py b/src/aiida/engine/processes/functions.py index 8bca68f55c..2f8c363e1a 100644 --- a/src/aiida/engine/processes/functions.py +++ b/src/aiida/engine/processes/functions.py @@ -567,7 +567,7 @@ def _setup_db_record(self) -> None: self.node.store_source_info(self._func) @override - def run(self) -> 'ExitCode' | None: + async def run(self) -> 'ExitCode' | None: """Run the process.""" from .exit_code import ExitCode diff --git a/src/aiida/engine/processes/workchains/workchain.py b/src/aiida/engine/processes/workchains/workchain.py index 8818db5eb8..4b847722a0 100644 --- a/src/aiida/engine/processes/workchains/workchain.py +++ b/src/aiida/engine/processes/workchains/workchain.py @@ -297,7 +297,7 @@ def _update_process_status(self) -> None: @override @Protect.final - def run(self) -> t.Any: + async def run(self) -> t.Any: self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type] return self._do_step() diff --git a/src/aiida/engine/transports.py b/src/aiida/engine/transports.py index fe32df7884..cade4c04ca 100644 --- a/src/aiida/engine/transports.py +++ b/src/aiida/engine/transports.py @@ -13,12 +13,12 @@ import contextvars import logging import traceback -from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional +from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional, Union from aiida.orm import AuthInfo if TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport _LOGGER = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def loop(self) -> asyncio.AbstractEventLoop: return self._loop @contextlib.contextmanager - def request_transport(self, authinfo: 
AuthInfo) -> Iterator[Awaitable['Transport']]: + def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Union['Transport', 'AsyncTransport']]]: """Request a transport from an authinfo. Because the client is not allowed to request a transport immediately they will instead be given back a future that can be awaited to get the transport:: diff --git a/src/aiida/orm/authinfos.py b/src/aiida/orm/authinfos.py index e87be97367..bff6ef849d 100644 --- a/src/aiida/orm/authinfos.py +++ b/src/aiida/orm/authinfos.py @@ -8,7 +8,7 @@ ########################################################################### """Module for the `AuthInfo` ORM class.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union from aiida.common import exceptions from aiida.manage import get_manager @@ -21,7 +21,7 @@ from aiida.orm import Computer, User from aiida.orm.implementation import StorageBackend from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401 - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport __all__ = ('AuthInfo',) @@ -166,7 +166,7 @@ def get_workdir(self) -> str: except KeyError: return self.computer.get_workdir() - def get_transport(self) -> 'Transport': + def get_transport(self) -> Union['Transport', 'AsyncTransport']: """Return a fully configured transport that can be used to connect to the computer set for this instance.""" computer = self.computer transport_type = computer.transport_type diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index bae925b25c..9bf12fbb2e 100644 --- a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -23,7 +23,7 @@ from aiida.orm import AuthInfo, User from aiida.orm.implementation import StorageBackend from aiida.schedulers import Scheduler - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport __all__ = ('Computer',) @@ -622,16 +622,16 @@ def is_user_enabled(self, user: 'User') -> bool: # Return False if the user is not configured (in a sense, it is disabled for that user) return False - def get_transport(self, user: Optional['User'] = None) -> 'Transport': + def get_transport(self, user: Optional['User'] = None) -> Union['Transport', 'AsyncTransport']: """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have - to open a connection). To do this you can call ``transports.open()``, or simply + to open a connection). To do this you can call ``transport.open()``, or simply run within a ``with`` statement:: transport = Computer.get_transport() with transport: - print(transports.whoami()) + print(transport.whoami()) :param user: if None, try to obtain a transport for the default user. Otherwise, pass a valid User. @@ -646,7 +646,7 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport': authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user) return authinfo.get_transport() - def get_transport_class(self) -> Type['Transport']: + def get_transport_class(self) -> Union[Type['Transport'], Type['AsyncTransport']]: """Get the transport class for this computer. 
Can be used to instantiate a transport instance.""" try: return TransportFactory(self.transport_type) diff --git a/src/aiida/orm/nodes/data/remote/base.py b/src/aiida/orm/nodes/data/remote/base.py index 1fc691d113..60e6f9bbee 100644 --- a/src/aiida/orm/nodes/data/remote/base.py +++ b/src/aiida/orm/nodes/data/remote/base.py @@ -117,7 +117,8 @@ def listdir_withattributes(self, path='.'): """Connects to the remote folder and lists the directory content. :param relpath: If 'relpath' is specified, lists the content of the given subfolder. - :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. + :return: a list of dictionaries, where the documentation + is in :py:class:Transport.listdir_withattributes. """ authinfo = self.get_authinfo() diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py index a7cd20c88e..3d448d957b 100644 --- a/src/aiida/orm/nodes/process/calculation/calcjob.py +++ b/src/aiida/orm/nodes/process/calculation/calcjob.py @@ -26,7 +26,7 @@ from aiida.parsers import Parser from aiida.schedulers.datastructures import JobInfo, JobState from aiida.tools.calculations import CalculationTools - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport __all__ = ('CalcJobNode',) @@ -450,10 +450,11 @@ def get_authinfo(self) -> 'AuthInfo': return computer.get_authinfo(self.user) - def get_transport(self) -> 'Transport': + def get_transport(self) -> Union['Transport', 'AsyncTransport']: """Return the transport for this calculation. - :return: `Transport` configured with the `AuthInfo` associated to the computer of this node + :return: Union['Transport', 'AsyncTransport'] configured + with the `AuthInfo` associated to the computer of this node """ return self.get_authinfo().get_transport() diff --git a/src/aiida/orm/utils/remote.py b/src/aiida/orm/utils/remote.py index f55cedc35a..2a9846af7c 100644 --- a/src/aiida/orm/utils/remote.py +++ b/src/aiida/orm/utils/remote.py @@ -12,6 +12,7 @@ import os import typing as t +from typing import Union from aiida.orm.nodes.data.remote.base import RemoteData @@ -20,14 +21,14 @@ from aiida import orm from aiida.orm.implementation import StorageBackend - from aiida.transports import Transport + from aiida.transports import AsyncTransport, Transport -def clean_remote(transport: Transport, path: str) -> None: +def clean_remote(transport: Union['Transport', 'AsyncTransport'], path: str) -> None: """Recursively remove a remote folder, with the given absolute path, and all its contents. 
     The path should be made accessible through the transport channel, which should already be open

-    :param transport: an open Transport channel
+    :param transport: an open ``Transport`` or ``AsyncTransport`` channel
     :param path: an absolute path on the remote made available through the transport
     """
     if not isinstance(path, str):
diff --git a/src/aiida/plugins/factories.py b/src/aiida/plugins/factories.py
index d007ef0dd3..175abc9eb7 100644
--- a/src/aiida/plugins/factories.py
+++ b/src/aiida/plugins/factories.py
@@ -42,7 +42,7 @@
     from aiida.schedulers import Scheduler
     from aiida.tools.data.orbital import Orbital
     from aiida.tools.dbimporters import DbImporter
-    from aiida.transports import Transport
+    from aiida.transports import AsyncTransport, Transport


 def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> NoReturn:
@@ -410,15 +410,19 @@ def StorageFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint
 @overload
-def TransportFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Transport']: ...
+def TransportFactory(
+    entry_point_name: str, load: Literal[True] = True
+) -> Union[Type['Transport'], Type['AsyncTransport']]: ...


 @overload
 def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


-def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Transport']]:
-    """Return the `Transport` sub class registered under the given entry point.
+def TransportFactory(
+    entry_point_name: str, load: bool = True
+) -> Union[EntryPoint, Type['Transport'], Type['AsyncTransport']]:
+    """Return the ``Transport`` or ``AsyncTransport`` subclass registered under the given entry point.

     :param entry_point_name: the entry point name.
     :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself.
@@ -426,16 +430,16 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoi
     """
     from inspect import isclass

-    from aiida.transports import Transport
+    from aiida.transports import AsyncTransport, Transport

     entry_point_group = 'aiida.transports'
     entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
-    valid_classes = (Transport,)
+    valid_classes = (Transport, AsyncTransport)

     if not load:
         return entry_point

-    if isclass(entry_point) and issubclass(entry_point, Transport):
+    if isclass(entry_point) and (issubclass(entry_point, Transport) or issubclass(entry_point, AsyncTransport)):
         return entry_point

     raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
diff --git a/src/aiida/schedulers/plugins/direct.py b/src/aiida/schedulers/plugins/direct.py
index 694ff93863..0bed55bda4 100644
--- a/src/aiida/schedulers/plugins/direct.py
+++ b/src/aiida/schedulers/plugins/direct.py
@@ -192,7 +192,7 @@ def _get_submit_command(self, submit_script):
         directory.

         IMPORTANT: submit_script should be already escaped.
         """
-        submit_command = f'bash {submit_script} > /dev/null 2>&1 & echo $!'
+        submit_command = f'(bash {submit_script} > /dev/null 2>&1 & echo $!) &'

         self.logger.info(f'submitting with: {submit_command}')
diff --git a/src/aiida/schedulers/scheduler.py b/src/aiida/schedulers/scheduler.py
index 3cd4136984..3bb540c84a 100644
--- a/src/aiida/schedulers/scheduler.py
+++ b/src/aiida/schedulers/scheduler.py
@@ -12,6 +12,7 @@
 import abc
 import typing as t
+from typing import Union

 from aiida.common import exceptions, log, warnings
 from aiida.common.datastructures import CodeRunMode
@@ -21,7 +22,7 @@
 from aiida.schedulers.datastructures import JobInfo, JobResource, JobTemplate, JobTemplateCodeInfo

 if t.TYPE_CHECKING:
-    from aiida.transports import Transport
+    from aiida.transports import AsyncTransport, Transport

 __all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError')
@@ -365,7 +366,7 @@ def transport(self):
         return self._transport

-    def set_transport(self, transport: Transport):
+    def set_transport(self, transport: Union['Transport', 'AsyncTransport']):
         """Set the transport to be used to query the machine or to submit scripts.

         This class assumes that the transport is open and active.
diff --git a/src/aiida/tools/pytest_fixtures/__init__.py b/src/aiida/tools/pytest_fixtures/__init__.py
index c2729d16c5..e19d4c455e 100644
--- a/src/aiida/tools/pytest_fixtures/__init__.py
+++ b/src/aiida/tools/pytest_fixtures/__init__.py
@@ -22,6 +22,7 @@
     aiida_computer,
     aiida_computer_local,
     aiida_computer_ssh,
+    aiida_computer_ssh_async,
     aiida_localhost,
     ssh_key,
 )
@@ -33,6 +34,7 @@
     'aiida_computer',
     'aiida_computer_local',
     'aiida_computer_ssh',
+    'aiida_computer_ssh_async',
     'aiida_config',
     'aiida_config_factory',
     'aiida_config_tmp',
diff --git a/src/aiida/tools/pytest_fixtures/orm.py b/src/aiida/tools/pytest_fixtures/orm.py
index 0ed7ea18d7..076eac2ddb 100644
--- a/src/aiida/tools/pytest_fixtures/orm.py
+++ b/src/aiida/tools/pytest_fixtures/orm.py
@@ -190,6 +190,38 @@ def factory(label: str | None = None, configure: bool = True) -> 'Computer':
     return factory


+@pytest.fixture
+def aiida_computer_ssh_async(aiida_computer) -> t.Callable[[], 'Computer']:
+    """Factory to return a :class:`aiida.orm.computers.Computer` instance with ``core.ssh_async`` transport.
+
+    Usage::
+
+        def test(aiida_computer_ssh_async):
+            computer = aiida_computer_ssh_async(label='some-label', configure=True)
+            assert computer.transport_type == 'core.ssh_async'
+            assert computer.is_configured
+
+    The factory has the following signature:
+
+    :param label: The computer label. If not specified, a random UUID4 is used.
+    :param configure: Boolean, if ``True``, ensures the computer is configured, otherwise the computer is returned
+        as is. Note that if a computer with the given label already exists and it was configured before, the
+        computer will not be "un-"configured. If an unconfigured computer is absolutely required, make sure to first
+        delete the existing computer or specify another label.
+    :return: A stored computer instance.
+    """
+
+    def factory(label: str | None = None, configure: bool = True) -> 'Computer':
+        computer = aiida_computer(label=label, hostname='localhost', transport_type='core.ssh_async')
+
+        if configure:
+            computer.configure()
+
+        return computer
+
+    return factory
+
+
 @pytest.fixture
 def aiida_localhost(aiida_computer_local) -> 'Computer':
     """Return a :class:`aiida.orm.computers.Computer` instance representing localhost with ``core.local`` transport.
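With the new fixture registered, a test module that loads the AiiDA pytest fixtures can request a computer backed by the asynchronous SSH transport in the same way as the existing ``core.ssh`` fixture. A minimal sketch that simply restates the guarantees documented in the fixture docstring:

    def test_ssh_async_computer(aiida_computer_ssh_async):
        """Smoke test for a computer configured with the ``core.ssh_async`` transport."""
        computer = aiida_computer_ssh_async(label='async-localhost', configure=True)
        assert computer.transport_type == 'core.ssh_async'
        assert computer.is_configured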
diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index eecd07c04f..0d36fe3980 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -16,8 +16,10 @@ from .transport import * __all__ = ( + 'AsyncTransport', 'SshTransport', 'Transport', + 'TransportPath', 'convert_to_bool', 'parse_sshconfig', ) diff --git a/src/aiida/transports/cli.py b/src/aiida/transports/cli.py index 6088eb08f6..5faa2d6f80 100644 --- a/src/aiida/transports/cli.py +++ b/src/aiida/transports/cli.py @@ -140,7 +140,7 @@ def transport_options(transport_type): """Decorate a command with all options for a computer configure subcommand for transport_type.""" def apply_options(func): - """Decorate the command functionn with the appropriate options for the transport type.""" + """Decorate the command function with the appropriate options for the transport type.""" options_list = list_transport_options(transport_type) options_list.reverse() func = arguments.COMPUTER(callback=partial(match_comp_transport, transport_type=transport_type))(func) diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index 8de49838e3..56fa042734 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -15,9 +15,11 @@ import os import shutil import subprocess +from typing import Optional +from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli -from aiida.transports.transport import Transport, TransportInternalError +from aiida.transports.transport import Transport, TransportInternalError, TransportPath, path_to_str # refactor or raise the limit: issue #1784 @@ -92,7 +94,7 @@ def curdir(self): raise TransportInternalError('Error, local method called for LocalTransport without opening the channel first') - def chdir(self, path): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. @@ -101,6 +103,11 @@ def chdir(self, path): :param path: path to cd into :raise OSError: if the directory does not have read attributes. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) + path = path_to_str(path) new_path = os.path.join(self.curdir, path) if not os.path.isdir(new_path): raise OSError(f"'{new_path}' is not a valid directory") @@ -109,13 +116,15 @@ def chdir(self, path): self._internal_dir = os.path.normpath(new_path) - def chown(self, path, uid, gid): + def chown(self, path: TransportPath, uid, gid): + path = path_to_str(path) os.chown(path, uid, gid) - def normalize(self, path='.'): + def normalize(self, path: TransportPath = '.'): """Normalizes path, eliminating double slashes, etc.. 
:param path: path to normalize """ + path = path_to_str(path) return os.path.realpath(os.path.join(self.curdir, path)) def getcwd(self): @@ -127,8 +136,9 @@ def getcwd(self): return self.curdir @staticmethod - def _os_path_split_asunder(path): + def _os_path_split_asunder(path: TransportPath): """Used by makedirs, Takes path (a str) and returns a list deconcatenating the path.""" + path = path_to_str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -142,7 +152,7 @@ def _os_path_split_asunder(path): parts.reverse() return parts - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -153,6 +163,7 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists and is not ignore_existing """ + path = path_to_str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -168,7 +179,7 @@ def makedirs(self, path, ignore_existing=False): if not os.path.exists(this_dir): os.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -177,33 +188,37 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = path_to_str(path) if ignore_existing and self.isdir(path): return os.mkdir(os.path.join(self.curdir, path)) - def rmdir(self, path): + def rmdir(self, path: TransportPath): """Removes a folder at location path. :param path: path to remove """ + path = path_to_str(path) os.rmdir(os.path.join(self.curdir, path)) - def isdir(self, path): + def isdir(self, path: TransportPath): """Checks if 'path' is a directory. :return: a boolean """ + path = path_to_str(path) if not path: return False return os.path.isdir(os.path.join(self.curdir, path)) - def chmod(self, path, mode): + def chmod(self, path: TransportPath, mode): """Changes permission bits of object at path :param path: path to modify :param mode: permission bits :raise OSError: if path does not exist. """ + path = path_to_str(path) if not path: raise OSError('Directory not given in input') real_path = os.path.join(self.curdir, path) @@ -214,7 +229,7 @@ def chmod(self, path, mode): # please refactor: issue #1782 - def put(self, localpath, remotepath, *args, **kwargs): + def put(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file or a folder from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -228,6 +243,8 @@ def put(self, localpath, remotepath, *args, **kwargs): :raise OSError: if remotepath is not valid :raise ValueError: if localpath is not valid """ + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) from aiida.common.warnings import warn_deprecation if 'ignore_noexisting' in kwargs: @@ -294,7 +311,7 @@ def put(self, localpath, remotepath, *args, **kwargs): else: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath, remotepath, *args, **kwargs): + def putfile(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file from localpath to remotepath. Automatically redirects to putfile or puttree. 
@@ -307,6 +324,9 @@ def putfile(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) + overwrite = kwargs.get('overwrite', args[0] if args else True) if not remotepath: raise OSError('Input remotepath to putfile must be a non empty string') @@ -325,7 +345,7 @@ def putfile(self, localpath, remotepath, *args, **kwargs): shutil.copyfile(localpath, the_destination) - def puttree(self, localpath, remotepath, *args, **kwargs): + def puttree(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a folder recursively from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -340,6 +360,8 @@ def puttree(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -365,11 +387,12 @@ def puttree(self, localpath, remotepath, *args, **kwargs): shutil.copytree(localpath, the_destination, symlinks=not dereference, dirs_exist_ok=overwrite) - def rmtree(self, path): + def rmtree(self, path: TransportPath): """Remove tree as rm -r would do :param path: a string to path """ + path = path_to_str(path) the_path = os.path.join(self.curdir, path) try: shutil.rmtree(the_path) @@ -383,7 +406,7 @@ def rmtree(self, path): # please refactor: issue #1781 - def get(self, remotepath, localpath, *args, **kwargs): + def get(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder or a file recursively from 'remote' remotepath to 'local' localpath. Automatically redirects to getfile or gettree. @@ -398,6 +421,8 @@ def get(self, remotepath, localpath, *args, **kwargs): :raise OSError: if 'remote' remotepath is not valid :raise ValueError: if 'local' localpath is not valid """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) ignore_nonexisting = kwargs.get('ignore_nonexisting', args[2] if len(args) > 2 else False) @@ -449,7 +474,7 @@ def get(self, remotepath, localpath, *args, **kwargs): else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, *args, **kwargs): + def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a file recursively from 'remote' remotepath to 'local' localpath. @@ -462,6 +487,9 @@ def getfile(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -480,7 +508,7 @@ def getfile(self, remotepath, localpath, *args, **kwargs): shutil.copyfile(the_source, localpath) - def gettree(self, remotepath, localpath, *args, **kwargs): + def gettree(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder recursively from 'remote' remotepath to 'local' localpath. 
@@ -493,6 +521,8 @@ def gettree(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -519,7 +549,7 @@ def gettree(self, remotepath, localpath, *args, **kwargs): # please refactor: issue #1780 on github - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False, recursive=True): """Copies a file or a folder from 'remote' remotesource to 'remote' remotedestination. Automatically redirects to copyfile or copytree. @@ -532,6 +562,8 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru :raise ValueError: if 'remote' remotesource or remotedestinationis not valid :raise OSError: if remotesource does not exist """ + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copy must be a non empty object') if not remotedestination: @@ -579,7 +611,7 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru # With self.copytree, the (possible) relative path is OK self.copytree(remotesource, remotedestination, dereference) - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a file from 'remote' remotesource to 'remote' remotedestination. @@ -590,6 +622,8 @@ def copyfile(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copyfile must be a non empty object') if not remotedestination: @@ -605,7 +639,7 @@ def copyfile(self, remotesource, remotedestination, dereference=False): else: shutil.copyfile(the_source, the_destination) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a folder from 'remote' remotesource to 'remote' remotedestination. @@ -616,6 +650,8 @@ def copytree(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copytree must be a non empty object') if not remotedestination: @@ -631,11 +667,12 @@ def copytree(self, remotesource, remotedestination, dereference=False): shutil.copytree(the_source, the_destination, symlinks=not dereference) - def get_attribute(self, path): + def get_attribute(self, path: TransportPath): """Returns an object FileAttribute, as specified in aiida.transports. :param path: the path of the given file. 
""" + path = path_to_str(path) from aiida.transports.util import FileAttribute os_attr = os.lstat(os.path.join(self.curdir, path)) @@ -646,10 +683,12 @@ def get_attribute(self, path): aiida_attr[key] = getattr(os_attr, key) return aiida_attr - def _local_listdir(self, path, pattern=None): + def _local_listdir(self, path: TransportPath, pattern=None): """Act on the local folder, for the rest, same as listdir.""" import re + path = path_to_str(path) + if not pattern: return os.listdir(path) @@ -663,12 +702,13 @@ def _local_listdir(self, path, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """:return: a list containing the names of the entries in the directory. :param path: default ='.' :param pattern: if set, returns the list of files matching pattern. Unix only. (Use to emulate ls * for example) """ + path = path_to_str(path) the_path = os.path.join(self.curdir, path).strip() if not pattern: try: @@ -685,20 +725,22 @@ def listdir(self, path='.', pattern=None): the_path += '/' return [re.sub(the_path, '', i) for i in filtered_list] - def remove(self, path): + def remove(self, path: TransportPath): """Removes a file at position path.""" + path = path_to_str(path) os.remove(os.path.join(self.curdir, path)) - def isfile(self, path): + def isfile(self, path: TransportPath): """Checks if object at path is a file. Returns a boolean. """ + path = path_to_str(path) if not path: return False return os.path.isfile(os.path.join(self.curdir, path)) @contextlib.contextmanager - def _exec_command_internal(self, command, workdir=None, **kwargs): + def _exec_command_internal(self, command, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command in bash login shell. @@ -723,12 +765,13 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): """ from aiida.common.escaping import escape_for_bash + if workdir: + workdir = path_to_str(workdir) # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. bash_commmand = f'{self._bash_command_str}-c ' command = bash_commmand + escape_for_bash(command) - if workdir: cwd = workdir else: @@ -745,7 +788,7 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): ) as process: yield process - def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): + def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command and waits for it to finish. :param command: the command to execute @@ -757,6 +800,8 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both bytes and the return_value is an int. """ + if workdir: + workdir = path_to_str(workdir) with self._exec_command_internal(command, workdir) as process: if stdin is not None: # Implicitly assume that the desired encoding is 'utf-8' if I receive a string. @@ -799,7 +844,7 @@ def line_encoder(iterator, encoding='utf-8'): return retval, output_text, stderr_text - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. 
@@ -810,11 +855,12 @@ def gotocomputer_command(self, remotedir): :param str remotedir: the full path of the remote directory """ + remotedir = path_to_str(remotedir) connect_string = self._gotocomputer_string(remotedir) cmd = f'bash -c {connect_string}' return cmd - def rename(self, oldpath, newpath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param str oldpath: existing name of the file or folder @@ -823,6 +869,8 @@ def rename(self, oldpath, newpath): :raises OSError: if src/dst is not found :raises ValueError: if src/dst is not a valid string """ + oldpath = path_to_str(oldpath) + newpath = path_to_str(newpath) if not oldpath: raise ValueError(f'Source {oldpath} is not a valid string') if not newpath: @@ -834,15 +882,15 @@ def rename(self, oldpath, newpath): shutil.move(oldpath, newpath) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote remotedestination :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ - remotesource = os.path.normpath(remotesource) - remotedestination = os.path.normpath(remotedestination) + remotesource = os.path.normpath(path_to_str(remotesource)) + remotedestination = os.path.normpath(path_to_str(remotedestination)) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -861,8 +909,9 @@ def symlink(self, remotesource, remotedestination): except OSError: raise OSError(f'!!: {remotesource}, {self.curdir}, {remotedestination}') - def path_exists(self, path): + def path_exists(self, path: TransportPath): """Check if path exists""" + path = path_to_str(path) return os.path.exists(os.path.join(self.curdir, path)) diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index 6858da5d2a..8cfe607a34 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -19,8 +19,9 @@ from aiida.cmdline.params import options from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType from aiida.common.escaping import escape_for_bash +from aiida.common.warnings import warn_deprecation -from ..transport import Transport, TransportInternalError +from ..transport import Transport, TransportInternalError, TransportPath, path_to_str __all__ = ('SshTransport', 'convert_to_bool', 'parse_sshconfig') @@ -230,6 +231,10 @@ class SshTransport(Transport): # if too large commands are sent, clogging the outputs or logs _MAX_EXEC_COMMAND_LOG_SIZE = None + # NOTE: all the methods that start with _get_ are class methods that + # return a suggestion for the specific field. They are being used in + # a function called transport_option_default in transports/cli.py, + # during an interactive `verdi computer configure` command. @classmethod def _get_username_suggestion_string(cls, computer): """Return a suggestion for the specific field.""" @@ -580,7 +585,7 @@ def __str__(self): return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]" - def chdir(self, path): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. @@ -590,8 +595,13 @@ def chdir(self, path): Differently from paramiko, if you pass None to chdir, nothing happens and the cwd is unchanged. 
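The NOTE above refers to the interactive `verdi computer configure` flow; a rough, simplified sketch of how such a suggestion hook is looked up (illustrative only, the real logic in `aiida/transports/cli.py` differs in its details):

    def transport_option_default(name, computer):
        """Return the transport class's suggestion for option `name`, if the class defines a hook for it."""
        transport_cls = computer.get_transport_class()
        suggester = getattr(transport_cls, f'_get_{name}_suggestion_string', None)
        return suggester(computer) if suggester is not None else None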
""" + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) from paramiko.sftp import SFTPError + path = path_to_str(path) old_path = self.sftp.getcwd() if path is not None: try: @@ -618,11 +628,13 @@ def chdir(self, path): self.chdir(old_path) raise OSError(str(exc)) - def normalize(self, path='.'): + def normalize(self, path: TransportPath = '.'): """Returns the normalized path (removing double slashes, etc...)""" + path = path_to_str(path) + return self.sftp.normalize(path) - def stat(self, path): + def stat(self, path: TransportPath): """Retrieve information about a file on the remote system. The return value is an object whose attributes correspond to the attributes of Python's ``stat`` structure as returned by ``os.stat``, except that it @@ -635,9 +647,11 @@ def stat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = path_to_str(path) + return self.sftp.stat(path) - def lstat(self, path): + def lstat(self, path: TransportPath): """Retrieve information about a file on the remote system, without following symbolic links (shortcuts). This otherwise behaves exactly the same as `stat`. @@ -647,6 +661,8 @@ def lstat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = path_to_str(path) + return self.sftp.lstat(path) def getcwd(self): @@ -659,9 +675,13 @@ def getcwd(self): this method will return None. But in __enter__ this is set explicitly, so this should never happen within this class. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) return self.sftp.getcwd() - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing: bool = False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -676,6 +696,8 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = path_to_str(path) + # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -697,7 +719,7 @@ def makedirs(self, path, ignore_existing=False): if not self.isdir(this_dir): self.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing: bool = False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -706,6 +728,8 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = path_to_str(path) + if ignore_existing and self.isdir(path): return @@ -725,7 +749,7 @@ def mkdir(self, path, ignore_existing=False): 'or the directory already exists? ({})'.format(path, self.getcwd(), exc) ) - def rmtree(self, path): + def rmtree(self, path: TransportPath): """Remove a file or a directory at path, recursively Flags used: -r: recursive copy; -f: force, makes the command non interactive; @@ -733,6 +757,7 @@ def rmtree(self, path): :raise OSError: if the rm execution failed. """ + path = path_to_str(path) # Assuming linux rm command! rm_exe = 'rm' @@ -752,25 +777,29 @@ def rmtree(self, path): self.logger.error(f"Problem executing rm. Exit code: {retval}, stdout: '{stdout}', stderr: '{stderr}'") raise OSError(f'Error while executing rm. 
 
-    def rmdir(self, path):
+    def rmdir(self, path: TransportPath):
         """Remove the folder named 'path' if empty."""
+        path = path_to_str(path)
         self.sftp.rmdir(path)
 
-    def chown(self, path, uid, gid):
+    def chown(self, path: TransportPath, uid, gid):
         """Change owner permissions of a file.
 
         For now, this is not implemented for the SSH transport.
         """
         raise NotImplementedError
 
-    def isdir(self, path):
+    def isdir(self, path: TransportPath):
         """Return True if the given path is a directory, False otherwise.
         Return False also if the path does not exist.
         """
         # Return False on empty string (paramiko would map this to the local
         # folder instead)
+        path = path_to_str(path)
+
         if not path:
             return False
         try:
             return S_ISDIR(self.stat(path).st_mode)
         except OSError as exc:
@@ -779,21 +808,24 @@ def isdir(self, path):
                 return False
             raise  # Typically if I don't have permissions (errno=13)
 
-    def chmod(self, path, mode):
+    def chmod(self, path: TransportPath, mode):
         """Change permissions to path
 
         :param path: path to file
         :param mode: new permission bits (integer)
         """
+        path = path_to_str(path)
+
         if not path:
             raise OSError('Input path is an empty argument.')
         return self.sftp.chmod(path, mode)
 
     @staticmethod
-    def _os_path_split_asunder(path):
-        """Used by makedirs. Takes path (a str)
+    def _os_path_split_asunder(path: TransportPath):
+        """Used by makedirs. Takes path
         and returns a list deconcatenating the path
         """
+        path = path_to_str(path)
         parts = []
         while True:
             newpath, tail = os.path.split(path)
@@ -807,7 +839,15 @@ def _os_path_split_asunder(path):
         parts.reverse()
         return parts
 
-    def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False):
+    def put(
+        self,
+        localpath: TransportPath,
+        remotepath: TransportPath,
+        callback=None,
+        dereference: bool = True,
+        overwrite: bool = True,
+        ignore_nonexisting: bool = False,
+    ):
         """Put a file or a folder from local to remote.
         Redirects to putfile or puttree.
 
@@ -821,6 +861,9 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=
         :raise ValueError: if local path is invalid
         :raise OSError: if the localpath does not exist
         """
+        localpath = path_to_str(localpath)
+        remotepath = path_to_str(remotepath)
+
         if not dereference:
             raise NotImplementedError
 
@@ -871,7 +914,14 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=
         elif not ignore_nonexisting:
             raise OSError(f'The local path {localpath} does not exist')
 
-    def putfile(self, localpath, remotepath, callback=None, dereference=True, overwrite=True):
+    def putfile(
+        self,
+        localpath: TransportPath,
+        remotepath: TransportPath,
+        callback=None,
+        dereference: bool = True,
+        overwrite: bool = True,
+    ):
         """Put a file from local to remote.
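Because `put`/`putfile` (and the corresponding `get` methods below) now take `TransportPath`, callers may pass `pathlib` objects directly. A minimal usage sketch, assuming a configured password-less SSH setup; the host name and paths are placeholders:

    from pathlib import Path

    from aiida.transports.plugins.ssh import SshTransport

    with SshTransport(machine='my.cluster.example.org') as transport:
        transport.put(Path('/home/me/input.txt'), '/scratch/me/job/input.txt')
        transport.get('/scratch/me/job/output.txt', Path('/home/me/output.txt'))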
:param localpath: an (absolute) local path @@ -883,6 +933,9 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr :raise OSError: if the localpath does not exist, or unintentionally overwriting """ + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) + if not dereference: raise NotImplementedError @@ -894,7 +947,14 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr return self.sftp.put(localpath, remotepath, callback=callback) - def puttree(self, localpath, remotepath, callback=None, dereference=True, overwrite=True): + def puttree( + self, + localpath: TransportPath, + remotepath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Put a folder recursively from local to remote. By default, overwrite. @@ -913,6 +973,9 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr .. note:: setting dereference equal to True could cause infinite loops. see os.walk() documentation """ + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) + if not dereference: raise NotImplementedError @@ -958,7 +1021,15 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr this_remote_file = os.path.join(remotepath, this_basename, this_file) self.putfile(this_local_file, this_remote_file) - def get(self, remotepath, localpath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False): + def get( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ignore_nonexisting: bool = False, + ): """Get a file or folder from remote to local. Redirects to getfile or gettree. @@ -973,6 +1044,9 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) + if not dereference: raise NotImplementedError @@ -1020,7 +1094,14 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def getfile( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Get a file from remote to local. :param remotepath: a remote path @@ -1031,6 +1112,9 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -1050,7 +1134,14 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr pass raise - def gettree(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def gettree( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Get a folder recursively from remote to local. :param remotepath: a remote path @@ -1059,12 +1150,14 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr Default = True (default behaviour in paramiko). 
False is not implemented. :param overwrite: if True overwrites files and folders. - Default = False + Default = True :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not dereference: raise NotImplementedError @@ -1101,10 +1194,11 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr else: self.getfile(os.path.join(remotepath, item), os.path.join(dest, item)) - def get_attribute(self, path): + def get_attribute(self, path: TransportPath): """Returns the object Fileattribute, specified in aiida.transports Receives in input the path of a given file. """ + path = path_to_str(path) from aiida.transports.util import FileAttribute paramiko_attr = self.lstat(path) @@ -1115,13 +1209,25 @@ def get_attribute(self, path): aiida_attr[key] = getattr(paramiko_attr, key) return aiida_attr - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) + return self.copy(remotesource, remotedestination, dereference) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) + return self.copy(remotesource, remotedestination, dereference, recursive=True) - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy( + self, + remotesource: TransportPath, + remotedestination: TransportPath, + dereference: bool = False, + recursive: bool = True, + ): """Copy a file or a directory from remote source to remote destination. Flags used: ``-r``: recursive copy; ``-f``: force, makes the command non interactive; ``-L`` follows symbolic links @@ -1138,6 +1244,9 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru .. note:: setting dereference equal to True could cause infinite loops. """ + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) + # In the majority of cases, we should deal with linux cp commands cp_flags = '-f' if recursive: @@ -1179,7 +1288,7 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru else: self._exec_cp(cp_exe, cp_flags, remotesource, remotedestination) - def _exec_cp(self, cp_exe, cp_flags, src, dst): + def _exec_cp(self, cp_exe: str, cp_flags: str, src: str, dst: str): """Execute the ``cp`` command on the remote machine.""" # to simplify writing the above copy function command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}' @@ -1205,7 +1314,7 @@ def _exec_cp(self, cp_exe, cp_flags, src, dst): ) @staticmethod - def _local_listdir(path, pattern=None): + def _local_listdir(path: str, pattern=None): """Acts on the local folder, for the rest, same as listdir""" if not pattern: return os.listdir(path) @@ -1219,13 +1328,15 @@ def _local_listdir(path, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """Get the list of files at path. :param path: default = '.' 
:param pattern: returns the list of files matching pattern.
                     Unix only. (Use to emulate ``ls *`` for example)
         """
+        path = path_to_str(path)
+
         if path.startswith('/'):
             abs_dir = path
         else:
@@ -1239,33 +1350,41 @@ def listdir(self, path='.', pattern=None):
             abs_dir += '/'
         return [re.sub(abs_dir, '', i) for i in filtered_list]
 
-    def remove(self, path):
+    def remove(self, path: TransportPath):
         """Remove a single file at 'path'"""
+        path = path_to_str(path)
        return self.sftp.remove(path)
 
-    def rename(self, oldpath, newpath):
+    def rename(self, oldpath: TransportPath, newpath: TransportPath):
         """Rename a file or folder from oldpath to newpath.
 
         :param str oldpath: existing name of the file or folder
         :param str newpath: new name for the file or folder
 
         :raises OSError: if oldpath/newpath is not found
-        :raises ValueError: if sroldpathc/newpath is not a valid string
+        :raises ValueError: if oldpath/newpath is not a valid path
         """
         if not oldpath:
-            raise ValueError(f'Source {oldpath} is not a valid string')
+            raise ValueError(f'Source {oldpath} is not a valid path')
         if not newpath:
-            raise ValueError(f'Destination {newpath} is not a valid string')
+            raise ValueError(f'Destination {newpath} is not a valid path')
+
+        oldpath = path_to_str(oldpath)
+        newpath = path_to_str(newpath)
+
         if not self.isfile(oldpath):
             if not self.isdir(oldpath):
                 raise OSError(f'Source {oldpath} does not exist')
+        # TODO: this seems to be a bug (?)
+        # why raise an OSError if newpath does not exist?
+        # of course newpath should not exist yet: it is the name we are renaming to!
         if not self.isfile(newpath):
             if not self.isdir(newpath):
                 raise OSError(f'Destination {newpath} does not exist')
 
         return self.sftp.rename(oldpath, newpath)
 
-    def isfile(self, path):
+    def isfile(self, path: TransportPath):
         """Return True if the given path is a file, False otherwise.
         Return False also if the path does not exist.
         """
@@ -1274,6 +1393,8 @@ def isfile(self, path):
         # but this is just to be sure
         if not path:
             return False
+
+        path = path_to_str(path)
         try:
             self.logger.debug(
                 f"stat for path '{path}' ('{self.normalize(path)}'): {self.stat(path)} [{self.stat(path).st_mode}]"
@@ -1334,7 +1455,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, work
 
         return stdin, stdout, stderr, channel
 
     def exec_command_wait_bytes(
-        self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir=None
+        self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir: TransportPath = None
     ):
         """Executes the specified command and waits for it to finish.
 
@@ -1354,6 +1475,9 @@ def exec_command_wait_bytes(
         import socket
         import time
 
+        if workdir:
+            workdir = path_to_str(workdir)
+
         ssh_stdin, stdout, stderr, channel = self._exec_command_internal(
             command, combine_stderr, bufsize=bufsize, workdir=workdir
         )
@@ -1447,10 +1571,12 @@ def exec_command_wait_bytes(
 
         return (retval, b''.join(stdout_bytes), b''.join(stderr_bytes))
 
-    def gotocomputer_command(self, remotedir):
+    def gotocomputer_command(self, remotedir: TransportPath):
         """Specific gotocomputer string to connect to a given remote computer via ssh
         and directly go to the calculation folder.
""" + remotedir = path_to_str(remotedir) + further_params = [] if 'username' in self._connect_args: further_params.append(f"-l {escape_for_bash(self._connect_args['username'])}") @@ -1473,21 +1599,25 @@ def gotocomputer_command(self, remotedir): cmd = f'ssh -t {self._machine} {further_params_str} {connect_string}' return cmd - def _symlink(self, source, dest): + def _symlink(self, source: TransportPath, dest: TransportPath): """Wrap SFTP symlink call without breaking API :param source: source of link :param dest: link to create """ + source = path_to_str(source) + dest = path_to_str(dest) self.sftp.symlink(source, dest) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) # paramiko gives some errors if path is starting with '.' source = os.path.normpath(remotesource) dest = os.path.normpath(remotedestination) @@ -1495,7 +1625,7 @@ def symlink(self, remotesource, remotedestination): if self.has_magic(source): if self.has_magic(dest): # if there are patterns in dest, I don't know which name to assign - raise ValueError('Remotedestination cannot have patterns') + raise ValueError('`remotedestination` cannot have patterns') # find all files matching pattern for this_source in self.glob(source): @@ -1505,10 +1635,12 @@ def symlink(self, remotesource, remotedestination): else: self._symlink(source, dest) - def path_exists(self, path): + def path_exists(self, path: TransportPath): """Check if path exists""" import errno + path = path_to_str(path) + try: self.stat(path) except OSError as exc: diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py new file mode 100644 index 0000000000..307cffbbfa --- /dev/null +++ b/src/aiida/transports/plugins/ssh_async.py @@ -0,0 +1,1271 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
                                    #
+#                                                                         #
+# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
+# For further information on the license, see the LICENSE.txt file        #
+# For further information please visit http://www.aiida.net               #
+###########################################################################
+"""Plugin for transport over SSH asynchronously."""
+
+## TODO: put & get methods could be simplified with the asyncssh.sftp.mget() & put() method or sftp.glob()
+import asyncio
+import glob
+import os
+from pathlib import Path, PurePath
+from typing import Optional, Union
+
+import asyncssh
+import click
+from asyncssh import SFTPFileAlreadyExists
+
+from aiida.common.escaping import escape_for_bash
+from aiida.common.exceptions import InvalidOperation
+from aiida.transports.transport import (
+    AsyncTransport,
+    Transport,
+    TransportInternalError,
+    TransportPath,
+    path_to_str,
+    validate_positive_number,
+)
+
+__all__ = ('AsyncSshTransport',)
+
+
+def validate_script(ctx, param, value: str):
+    if value == 'None':
+        return value
+    if not os.path.isabs(value):
+        raise click.BadParameter(f'{value} is not an absolute path')
+    if not os.path.isfile(value):
+        raise click.BadParameter(f'The script file: {value} does not exist')
+    if not os.access(value, os.X_OK):
+        raise click.BadParameter(f'The script {value} is not executable')
+    return value
+
+
+def validate_machine(ctx, param, value: str):
+    async def attempt_connection():
+        try:
+            await asyncssh.connect(value)
+        except Exception:
+            return False
+        return True
+
+    if not asyncio.run(attempt_connection()):
+        raise click.BadParameter(f"Couldn't connect! Please make sure `ssh {value}` would work without a password")
+    else:
+        click.echo(f'`ssh {value}` successful!')
+
+    return value
+
+
+class AsyncSshTransport(AsyncTransport):
+    """Transport plugin via SSH, asynchronously."""
+
+    _DEFAULT_max_io_allowed = 8
+
+    # note, I intentionally wanted to keep connection parameters as simple as possible.
+    _valid_auth_options = [
+        (
+            # the underscore is added to avoid conflict with the machine property
+            # which is passed to __init__ as parameter `machine=computer.hostname`
+            'machine_or_host',
+            {
+                'type': str,
+                'prompt': 'Machine (or host) name as in the `ssh <host>` command.'
+                ' (It should be a password-less setup)',
+                'help': 'Password-less host setup to connect, as in the command `ssh <host>`. '
+                "You'll need to have a `Host <host>` entry defined in your `~/.ssh/config` file.",
+                'non_interactive_default': True,
+                'callback': validate_machine,
+            },
+        ),
+        (
+            'max_io_allowed',
+            {
+                'type': int,
+                'default': _DEFAULT_max_io_allowed,
+                'prompt': 'Maximum number of concurrent I/O operations.',
+                'help': 'Depends on various factors, such as your network bandwidth, the server load, etc.'
+                ' (An experimental number)',
+                'non_interactive_default': True,
+                'callback': validate_positive_number,
+            },
+        ),
+        (
+            'script_before',
+            {
+                'type': str,
+                'default': 'None',
+                'prompt': 'Local script to run *before* opening connection (path)',
+                'help': ' (optional) Specify a script to run *before* opening the SSH connection. '
+                'The script should be executable.',
+                'non_interactive_default': True,
+                'callback': validate_script,
+            },
+        ),
+    ]
+
+    @classmethod
+    def _get_machine_suggestion_string(cls, computer):
+        """Return a suggestion for the parameter machine."""
+        # Originally set as 'Hostname' during `verdi computer setup`
+        # and is passed as `machine=computer.hostname` in the codebase.
+        # Unfortunately, the names `hostname` and `machine` are used interchangeably in the aiida-core codebase.
+        # TODO: open an issue to unify the naming
+        return computer.hostname
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # the machine is passed as `machine=computer.hostname` in the codebase
+        # 'machine' is immutable.
+        # 'machine_or_host' is mutable, so it can be changed via command:
+        # 'verdi computer configure core.ssh_async