### `CLASS slurm_sweeps.Experiment`

```python
class Experiment(
    train: Callable,
    cfg: Dict,
    name: str = "MySweep",
    local_dir: Union[str, Path] = "./slurm-sweeps",
    asha: Optional[ASHA] = None,
    slurm_cfg: Optional[SlurmCfg] = None,
    restore: bool = False,
    overwrite: bool = False,
)
```

Set up an HPO experiment.

**Arguments**:

- `cfg` - A dict passed on to the `train` function.
  It must contain the search spaces via `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc.
- `name` - The name of the experiment.
- `local_dir` - Where to store and run the experiments. In this directory,
  we will create the database `slurm_sweeps.db` and a folder with the experiment name.
- `slurm_cfg` - The configuration of the Slurm backend responsible for running the trials.
  We automatically choose this backend when slurm sweeps is used within an sbatch script.
- `asha` - An optional ASHA instance to cancel less promising trials.
- `restore` - Restore an experiment with the same name?
- `overwrite` - Overwrite an existing experiment with the same name?
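
Putting the pieces together, here is a minimal sketch of setting up an experiment. The body of the `train` function, the exact signatures of `slurm_sweeps.Uniform` and `slurm_sweeps.Choice`, and the chosen values are illustrative assumptions, not taken from this reference.

```python
import slurm_sweeps as ss


def train(cfg: dict):
    # Dummy objective; report a metric back to slurm sweeps each iteration
    # via ss.log (see `slurm_sweeps.log` below).
    for epoch in range(cfg["epochs"]):
        loss = (cfg["lr"] - 0.01) ** 2 / (epoch + 1)
        ss.log({"loss": loss}, iteration=epoch)


# Assumed search-space helpers: Uniform(low, high) and Choice(list_of_options).
cfg = {
    "epochs": 10,
    "lr": ss.Uniform(1e-4, 1e-1),
    "batch_size": ss.Choice([16, 32, 64]),
}

experiment = ss.Experiment(
    train=train,
    cfg=cfg,
    name="MySweep",
    local_dir="./slurm-sweeps",
    overwrite=True,
)
```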

#### `Experiment.name`

```python
@property
def name() -> str
```

The name of the experiment.

#### `Experiment.local_dir`

```python
@property
def local_dir() -> Path
```

The local directory of the experiment.
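
Continuing the sketch above, the two read-only properties can be inspected directly (the printed values are whatever you configured):

```python
print(experiment.name)       # the experiment name, e.g. "MySweep"
print(experiment.local_dir)  # a pathlib.Path, per the `local_dir` property above
```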

#### `Experiment.run`

```python
def run(
    self,
    n_trials: int = 1,
    max_concurrent_trials: Optional[int] = None,
    summary_interval_in_sec: float = 5.0,
    nr_of_rows_in_summary: int = 10,
    summarize_cfg_and_metrics: Union[bool, List[str]] = True
) -> pd.DataFrame
```

Run the experiment.

**Arguments**:

- `n_trials` - Number of trials to run. For grid searches, this parameter is ignored.
- `max_concurrent_trials` - The maximum number of trials running concurrently. By default, we will set this to
  the number of cpus available, or the number of total Slurm tasks divided by the number of tasks
  requested per trial.
- `summary_interval_in_sec` - Print a summary of the experiment every x seconds.
- `nr_of_rows_in_summary` - How many rows of the summary table should we print?
- `summarize_cfg_and_metrics` - Should we include the cfg and the metrics in the summary table?
  You can also pass in a list of strings to only select a few cfg and metric keys.

**Returns**:

A summary of the trials in a pandas DataFrame.
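
As a rough usage sketch (continuing the example above; the trial counts and intervals are arbitrary):

```python
# Run 100 trials, at most 4 at a time, printing a summary every 10 seconds.
df = experiment.run(
    n_trials=100,
    max_concurrent_trials=4,
    summary_interval_in_sec=10.0,
)

# The returned pandas DataFrame summarizes the trials; its exact columns are
# not specified in this reference, so we just inspect it.
print(df.head())
```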

### `CLASS slurm_sweeps.ASHA`

```python
class ASHA(
    metric: str,
    mode: str,
    reduction_factor: int = 4,
    min_t: int = 1,
    max_t: int = 50,
)
```

Basic implementation of the Asynchronous Successive Halving Algorithm (ASHA) to prune unpromising trials.
Pass an instance of this class to your experiment.

**Arguments**:

- `metric` - The metric you want to optimize.
- `mode` - Should the metric be minimized or maximized? Allowed values: ["min", "max"]
- `reduction_factor` - The reduction factor of the algorithm.
- `min_t` - Minimum number of iterations before we consider pruning.
- `max_t` - Maximum number of iterations.
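
A hedged sketch of wiring ASHA into an experiment (reusing the assumed `train` and `cfg` from the first example; the metric name must match a key you pass to `slurm_sweeps.log`):

```python
asha = ss.ASHA(metric="loss", mode="min", reduction_factor=4, min_t=1, max_t=10)

experiment = ss.Experiment(
    train=train,
    cfg=cfg,
    asha=asha,  # less promising trials are cancelled based on the logged "loss"
)
df = experiment.run(n_trials=100)
```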

#### `ASHA.metric`

```python
@property
def metric() -> str
```

The metric to optimize.

#### `ASHA.mode`

```python
@property
def mode() -> str
```

The 'mode' of the metric, either 'max' or 'min'.

#### `ASHA.find_trials_to_prune`

```python
def find_trials_to_prune(database: "pd.DataFrame") -> List[str]
```

Check the database and find trials to prune.

**Arguments**:

- `database` - The experiment's metrics table of the database as a pandas DataFrame.

**Returns**:

List of trial ids that should be pruned.

### CLASS `slurm_sweeps.SlurmCfg`

```python
@dataclass
class SlurmCfg:
    exclusive: bool = True
    nodes: int = 1
    ntasks: int = 1
    args: str = ""
```

A configuration class for the SlurmBackend.

**Arguments**:

- `ntasks` - How many tasks do you request for your srun?
- `args` - Additional command line arguments for srun, formatted as a string.
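
A short, hedged sketch of configuring the Slurm backend for the trials (treating `exclusive` and `nodes` as the corresponding srun options is an assumption based on the field names):

```python
slurm_cfg = ss.SlurmCfg(
    exclusive=True,           # assumed to map to srun's --exclusive flag
    nodes=1,
    ntasks=2,                 # request 2 tasks per trial
    args="--mem-per-cpu=4G",  # extra command line arguments passed to srun
)

experiment = ss.Experiment(
    train=train,
    cfg=cfg,
    slurm_cfg=slurm_cfg,
)
```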

### FUNCTION `slurm_sweeps.log`

```python
def log(metrics: Dict[str, Union[float, int]], iteration: int)
```

Log metrics to the database.

If ASHA is configured, this also checks if the trial needs to be pruned.

**Arguments**:

- `metrics` - A dictionary containing the metrics.
- `iteration` - Iteration of the metrics. Most of the time this will be the epoch.

**Raises**:

- `TrialPruned` if the holy ASHA says so!
- `TypeError` if a metric is not of type `float` or `int`.
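
A minimal sketch of calling `log` from inside a train function. The dummy metric is an assumption, and `TrialPruned` is simply left to propagate, which presumably ends the pruned trial (the exception's import path is not shown in this reference):

```python
import slurm_sweeps as ss


def train(cfg: dict):
    for epoch in range(cfg["epochs"]):
        # Dummy metric; in a real run this would come from your model.
        accuracy = 1.0 - 1.0 / (epoch + 1)
        # Values must be float or int, otherwise log raises a TypeError.
        # If ASHA decides to prune this trial, log raises TrialPruned.
        ss.log({"accuracy": accuracy}, iteration=epoch)
```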

## Contact

David Carreto Fidalgo ([email protected])