Skip to content

Commit

Permalink
[minor] restructure create method (#551)
Browse files Browse the repository at this point in the history
* restructure create method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* remove unused line

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jan-janssen and pre-commit-ci[bot] authored Feb 1, 2025
1 parent 565bb32 commit 2a5c109
Showing 1 changed file with 42 additions and 15 deletions.
57 changes: 42 additions & 15 deletions executorlib/interactive/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ def create_executor(
of the individual function.
init_function (None): optional function to preset arguments for functions which are submitted later
"""
check_init_function(block_allocation=block_allocation, init_function=init_function)
if flux_executor is not None and backend != "flux_allocation":
backend = "flux_allocation"
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
if backend == "flux_allocation":
check_init_function(
block_allocation=block_allocation, init_function=init_function
)
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
check_oversubscribe(
oversubscribe=resource_dict.get("openmpi_oversubscribe", False)
)
Expand All @@ -100,40 +101,41 @@ def create_executor(
return create_flux_allocation_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
flux_executor=flux_executor,
flux_executor_pmi_mode=flux_executor_pmi_mode,
flux_executor_nesting=flux_executor_nesting,
flux_log_files=flux_log_files,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
elif backend == "slurm_allocation":
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
check_flux_log_files(flux_log_files=flux_log_files)
return create_slurm_allocation_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
elif backend == "local":
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
check_flux_log_files(flux_log_files=flux_log_files)
check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
return create_local_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
Expand All @@ -146,15 +148,25 @@ def create_executor(
def create_flux_allocation_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
resource_dict: dict = {},
flux_executor=None,
flux_executor_pmi_mode: Optional[str] = None,
flux_executor_nesting: bool = False,
flux_log_files: bool = False,
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
check_pmi(backend="flux_allocation", pmi=flux_executor_pmi_mode)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
check_oversubscribe(oversubscribe=resource_dict.get("openmpi_oversubscribe", False))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
if "openmpi_oversubscribe" in resource_dict.keys():
del resource_dict["openmpi_oversubscribe"]
if "slurm_cmd_args" in resource_dict.keys():
Expand Down Expand Up @@ -193,11 +205,16 @@ def create_flux_allocation_executor(
def create_slurm_allocation_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
resource_dict: dict = {},
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
if block_allocation:
resource_dict["init_function"] = init_function
max_workers = validate_number_of_cores(
Expand Down Expand Up @@ -228,11 +245,21 @@ def create_slurm_allocation_executor(
def create_local_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
resource_dict: dict = {},
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost

check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
if "threads_per_core" in resource_dict.keys():
del resource_dict["threads_per_core"]
if "gpus_per_core" in resource_dict.keys():
Expand Down

0 comments on commit 2a5c109

Please sign in to comment.