update readme
dcfidalgo committed Dec 11, 2023
1 parent a27f651 commit 5c1c71f
Showing 1 changed file with 112 additions and 22 deletions.
134 changes: 112 additions & 22 deletions README.md
@@ -99,14 +99,13 @@ See the `tests` folder for an advanced example of training a PyTorch model with
### `CLASS slurm_sweeps.Experiment`

```python
def __init__(
self,
class Experiment(
train: Callable,
cfg: Dict,
name: str = "MySweep",
local_dir: Union[str, Path] = "./slurm_sweeps",
backend: Optional[Backend] = None,
local_dir: Union[str, Path] = "./slurm-sweeps",
asha: Optional[ASHA] = None,
slurm_cfg: Optional[SlurmCfg] = None,
restore: bool = False,
overwrite: bool = False,
)
@@ -120,59 +119,130 @@ Set up an HPO experiment.
- `cfg` - A dict passed on to the `train` function.
It must contain the search spaces via `slurm_sweeps.Uniform`, `slurm_sweeps.Choice`, etc.
- `name` - The name of the experiment.
- `local_dir` - Where to store and run the experiments. In this directory
- `local_dir` - Where to store and run the experiments. In this directory,
we will create the database `slurm_sweeps.db` and a folder with the experiment name.
- `backend` - A backend to execute the trials. By default, we choose the `SlurmBackend` if Slurm is available,
otherwise we choose the standard `Backend` that simply executes the trial in another process.
- `slurm_cfg` - The configuration of the Slurm backend responsible for running the trials.
We automatically choose this backend when `slurm_sweeps` is used within an `sbatch` script.
- `asha` - An optional ASHA instance to cancel less promising trials.
- `restore` - Restore an experiment with the same name?
- `overwrite` - Overwrite an existing experiment with the same name?
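
As a rough illustration of how these arguments fit together, the sketch below builds an experiment whose `cfg` holds one fixed value and one search space. The positional bounds passed to `Uniform`, the toy loss, and the helper names `toy_loss` and `build_experiment` are assumptions made for this example. `slurm_sweeps` is imported lazily so that the snippet loads even where the package is not installed.

```python
from typing import Dict


def toy_loss(parameter: float, epoch: int) -> float:
    """Toy objective: shrinks as `parameter` approaches 1 and as epochs pass."""
    return (parameter - 1) ** 2 / (epoch + 1)


def train(cfg: Dict) -> None:
    """Train function handed to the experiment; it receives a sampled cfg."""
    import slurm_sweeps as ss  # lazy import, only needed inside a trial

    for epoch in range(cfg["epochs"]):
        # `log(metrics, iteration)` follows the signature documented in this README.
        ss.log({"loss": toy_loss(cfg["parameter"], epoch)}, epoch)


def build_experiment():
    """Assemble the experiment; requires `slurm_sweeps` to be installed."""
    import slurm_sweeps as ss

    return ss.Experiment(
        train=train,
        cfg={
            "epochs": 10,  # fixed value, passed through unchanged
            # Search space; positional (low, high) bounds are an assumption.
            "parameter": ss.Uniform(0, 2),
        },
        name="MySweep",
        asha=ss.ASHA(metric="loss", mode="min"),
    )
```

Calling `build_experiment().run(n_trials=...)` would then start the sweep.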

#### `Experiment.name`

```python
@property
def name() -> str
```

The name of the experiment.

#### `Experiment.local_dir`

```python
@property
def local_dir() -> Path
```

The local directory of the experiment.

#### `Experiment.run`

```python
def run(
self,
n_trials: int = 1,
max_concurrent_trials: Optional[int] = None,
summary_interval_in_sec: float = 5.0,
nr_of_rows_in_summary: int = 10,
summarize_cfg_and_metrics: Union[bool, List[str]] = True,
summarize_cfg_and_metrics: Union[bool, List[str]] = True
) -> pd.DataFrame
```

Run the experiment.

**Arguments**:

- `n_trials` - Number of trials to run. For grid searches this parameter is ignored.
- `n_trials` - Number of trials to run. For grid searches, this parameter is ignored.
- `max_concurrent_trials` - The maximum number of trials running concurrently. By default, we will set this to
the number of cpus available, or the number of total Slurm tasks divided by the number of trial Slurm
tasks requested.
the number of CPUs available, or the total number of Slurm tasks divided by the number of tasks
requested per trial.
- `summary_interval_in_sec` - Print a summary of the experiment every x seconds.
- `nr_of_rows_in_summary` - How many rows of the summary table should we print?
- `summarize_cfg_and_metrics` - Should we include the cfg and the metrics in the summary table?
You can also pass in a list of strings to only select a few cfg and metric keys.

**Returns**:

A DataFrame of the database.
A summary of the trials in a pandas DataFrame.
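
The default for `max_concurrent_trials` described above can be pictured with a small helper. This is only a sketch of the documented rule, not the library's code: the function name and its parameters are hypothetical, and falling back to `os.cpu_count()` is an assumption about how "number of CPUs available" is determined.

```python
import os
from typing import Optional


def default_max_concurrent_trials(
    total_slurm_tasks: Optional[int] = None,
    tasks_per_trial: int = 1,
) -> int:
    """Guess how many trials can run at once.

    Outside Slurm (total_slurm_tasks is None) use the CPU count.
    Inside Slurm, divide the allocated tasks by the tasks each trial requests.
    """
    if tasks_per_trial < 1:
        raise ValueError("tasks_per_trial must be at least 1")
    if total_slurm_tasks is None:
        return os.cpu_count() or 1
    # Integer division, but always allow at least one trial to run.
    return max(1, total_slurm_tasks // tasks_per_trial)
```

For example, an allocation of 16 Slurm tasks with 4 tasks per trial would give 4 concurrent trials under this rule.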

### `CLASS slurm_sweeps.SlurmBackend`
### `CLASS slurm_sweeps.ASHA`

```python
def __init__(
self,
exclusive: bool = True,
nodes: int = 1,
ntasks: int = 1,
args: str = ""
class ASHA(
metric: str,
mode: str,
reduction_factor: int = 4,
min_t: int = 1,
max_t: int = 50,
)
```

Execute the training runs on a Slurm cluster via `srun`.
Basic implementation of the Asynchronous Successive Halving Algorithm (ASHA) to prune unpromising trials.

**Arguments**:

- `metric` - The metric you want to optimize.
- `mode` - Should the metric be minimized or maximized? Allowed values: `"min"` or `"max"`.
- `reduction_factor` - The reduction factor of the algorithm.
- `min_t` - Minimum number of iterations before we consider pruning.
- `max_t` - Maximum number of iterations.
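
To get a feel for how `min_t`, `max_t` and `reduction_factor` interact, the sketch below lists the iterations ("rungs") at which a textbook ASHA scheduler would compare trials. It follows the usual geometric schedule and is an assumption about the general algorithm; the library's own implementation may place its rungs differently. The function name `asha_rungs` is hypothetical.

```python
from typing import List


def asha_rungs(min_t: int = 1, max_t: int = 50, reduction_factor: int = 4) -> List[int]:
    """Return the iterations min_t * reduction_factor**k that do not exceed max_t."""
    if min_t < 1 or reduction_factor < 2:
        raise ValueError("min_t must be >= 1 and reduction_factor must be >= 2")
    rungs = []
    t = min_t
    while t <= max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs


print(asha_rungs())  # [1, 4, 16] with the documented defaults
```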

#### `ASHA.metric`

```python
@property
def metric() -> str
```

The metric to optimize.

#### `ASHA.mode`

```python
@property
def mode() -> str
```

The 'mode' of the metric, either 'max' or 'min'.

#### `ASHA.find_trials_to_prune`

```python
def find_trials_to_prune(database: "pd.DataFrame") -> List[str]
```

Check the database and find trials to prune.

**Arguments**:

- `database` - The experiment's metrics table of the database as a pandas DataFrame.


Pass an instance of this class to your experiment.
**Returns**:

List of trial ids that should be pruned.
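
A pruning decision of this kind might look roughly like the following. This is a simplified successive-halving rule, not the library's implementation: the metrics table is represented as a plain list of dicts instead of a DataFrame, and the column names `trial_id` and `iteration`, as well as the function name, are hypothetical.

```python
from typing import Dict, List, Sequence, Union

Row = Dict[str, Union[str, int, float]]


def sketch_trials_to_prune(
    rows: Sequence[Row],
    metric: str,
    mode: str,
    rungs: Sequence[int],
    reduction_factor: int = 4,
) -> List[str]:
    """At every rung keep the best 1/reduction_factor of trials, prune the rest."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pruned = set()
    for rung in rungs:
        # All trials that have reported the metric at this rung.
        at_rung = [row for row in rows if row["iteration"] == rung]
        # Best trials first: ascending for "min", descending for "max".
        ranked = sorted(at_rung, key=lambda row: row[metric], reverse=(mode == "max"))
        n_keep = max(1, len(ranked) // reduction_factor)
        pruned.update(str(row["trial_id"]) for row in ranked[n_keep:])
    return sorted(pruned)
```

With four trials at rung 1 and `reduction_factor=4`, only the best trial survives and the other three ids are returned.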

### CLASS `slurm_sweeps.SlurmCfg`

```python
@dataclass
class SlurmCfg:
exclusive: bool = True
nodes: int = 1
ntasks: int = 1
args: str = ""
```

A configuration class for the SlurmBackend.

**Arguments**:

@@ -181,5 +251,25 @@ Pass an instance of this class to your experiment.
- `ntasks` - How many tasks do you request for your srun?
- `args` - Additional command line arguments for srun, formatted as a string.
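
One way to picture how these fields could end up on an `srun` command line is sketched below. The dataclass is a local mirror of the documented fields so that the snippet runs without the library, and `srun_command` is a hypothetical helper; how the actual backend assembles its command is not shown in this README.

```python
import shlex
from dataclasses import dataclass
from typing import List


@dataclass
class SlurmCfg:
    """Local mirror of the documented fields (not imported from slurm_sweeps)."""

    exclusive: bool = True
    nodes: int = 1
    ntasks: int = 1
    args: str = ""


def srun_command(cfg: SlurmCfg, trial_cmd: List[str]) -> List[str]:
    """Translate the config into an srun invocation for one trial."""
    cmd = ["srun", f"--nodes={cfg.nodes}", f"--ntasks={cfg.ntasks}"]
    if cfg.exclusive:
        cmd.append("--exclusive")
    # `args` is a single string, so split it the way a shell would.
    cmd.extend(shlex.split(cfg.args))
    return cmd + trial_cmd


print(srun_command(SlurmCfg(args="--cpus-per-task=4"), ["python", "train.py"]))
```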

### FUNCTION `slurm_sweeps.log`

```python
def log(metrics: Dict[str, Union[float, int]], iteration: int)
```

Log metrics to the database.

If ASHA is configured, this also checks if the trial needs to be pruned.

**Arguments**:

- `metrics` - A dictionary containing the metrics.
- `iteration` - Iteration of the metrics. Most of the time this will be the epoch.

**Raises**:

- `TrialPruned` if the holy ASHA says so!
- `TypeError` if a metric is not of type `float` or `int`.
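
The contract described here (store the metrics, then possibly prune, and reject non-numeric values) can be mimicked with a small stand-in. Everything below is hypothetical and local to the example: `make_log`, the in-memory `records` list, the `should_prune` callback, and the local `TrialPruned` class only imitate the documented behaviour. The order "write first, then check for pruning" is an assumption.

```python
from typing import Callable, Dict, List, Optional, Tuple, Union

Number = Union[float, int]


class TrialPruned(Exception):
    """Local stand-in for the exception named in the docs."""


def make_log(
    should_prune: Optional[Callable[[Dict[str, Number], int], bool]] = None,
):
    """Return a toy `log` function together with the list it writes to."""
    records: List[Tuple[int, Dict[str, Number]]] = []

    def log(metrics: Dict[str, Number], iteration: int) -> None:
        for key, value in metrics.items():
            # Note: bool is a subclass of int in Python, so this sketch accepts
            # booleans; whether the library does is not stated in the README.
            if not isinstance(value, (float, int)):
                raise TypeError(
                    f"Metric {key!r} must be float or int, got {type(value).__name__}"
                )
        records.append((iteration, dict(metrics)))
        if should_prune is not None and should_prune(metrics, iteration):
            raise TrialPruned(f"pruned at iteration {iteration}")

    return log, records
```

In a train function, a `TrialPruned` raised by `log` would simply end the loop for that trial.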

## Contact
David Carreto Fidalgo ([email protected])
