Skip to content

Commit

Permalink
feat: add Hydra-based "pyannote-audio-train" CLI
Browse files Browse the repository at this point in the history
* closes pyannote#476 (CLI)
* closes pyannote#514 (support for AutoLR)
* closes pyannote#485 (hyper-parameter optimization with Hydra’s Ax sweeper)
* closes pyannote#412 (log graph to Tensorboard)
  • Loading branch information
hbredin authored Nov 18, 2020
1 parent 298ba39 commit 06d3b61
Show file tree
Hide file tree
Showing 28 changed files with 449 additions and 46 deletions.
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
recursive-include pyannote *.py
recursive-include pyannote *.yaml
global-exclude *.pyc
global-exclude __pycache__
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# pyannote-audio-v2
Rewriting pyannote.audio from scratch


## CLI

```bash
pyannote-audio-train task=vad model=debug protocol=Debug.SpeakerDiarization.Debug
```


## Contributing

The commands below will set up the pre-commit hooks and packages needed for developing the `pyannote.audio` library.
Expand All @@ -12,8 +20,8 @@ pre-commit install

## Testing

Tests rely on a set of debugging files available in [`test/data`](test/data) directory.
Set `PYANNOTE_DATABASE_CONFIG` environment variable to `test/data/database.yml` before running tests:
Tests rely on a set of debugging files available in the [`tests/data`](tests/data) directory.
Set the `PYANNOTE_DATABASE_CONFIG` environment variable to `tests/data/database.yml` before running tests:

```bash
PYANNOTE_DATABASE_CONFIG=tests/data/database.yml pytest
Expand Down
2 changes: 1 addition & 1 deletion notebook/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
"outputs": [],
"source": [
"from pyannote.audio.tasks.speaker_verification.task import SpeakerEmbeddingArcFace\n",
"emb = SpeakerEmbeddingArcFace(protocol, duration=2., batch_size=32, num_workers=4)\n",
"emb = SpeakerEmbeddingArcFace(protocol, duration=2., num_workers=4)\n",
"model = SimpleEmbeddingModel(task=emb)\n",
"trainer = pl.Trainer(max_epochs=1)\n",
"_ = trainer.fit(model, emb)"
Expand Down
2 changes: 1 addition & 1 deletion notebook/inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
"outputs": [],
"source": [
"from pyannote.audio.tasks.speaker_verification.task import SpeakerEmbeddingArcFace\n",
"emb = SpeakerEmbeddingArcFace(protocol, duration=2., batch_size=32, num_workers=4)\n",
"emb = SpeakerEmbeddingArcFace(protocol, duration=2., num_workers=4)\n",
"from pyannote.audio.models.debug import SimpleEmbeddingModel\n",
"model = SimpleEmbeddingModel(task=emb)\n",
"trainer = pl.Trainer(max_epochs=1, default_root_dir='inference/emb')\n",
Expand Down
Empty file added pyannote/audio/cli/__init__.py
Empty file.
111 changes: 111 additions & 0 deletions pyannote/audio/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Iterable

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import Parameter
from torch.optim import Optimizer

from pyannote.database import FileFinder, get_protocol


@hydra.main(config_path="train_config", config_name="config")
def main(cfg: DictConfig) -> float:
    """Train a model on a task, as described by a Hydra configuration.

    Parameters
    ----------
    cfg : DictConfig
        Hydra configuration. The following entries are read: `protocol`,
        `task`, `model`, `optimizer` (including `optimizer.lr`), `trainer`
        (including the optional `trainer.auto_lr_find`), and `verbose`.

    Returns
    -------
    objective : float
        Best value of the monitored validation metric, as tracked by the
        early stopping callback. It is negated when the monitoring mode is
        not "min", so that a lower returned value is always better
        (presumably so that Hydra sweepers can minimize it).
    """

    protocol = get_protocol(cfg.protocol, preprocessors={"audio": FileFinder()})

    # TODO: configure augmentation
    # TODO: configure scheduler
    # TODO: configure layer freezing

    def optimizer(parameters: Iterable[Parameter], lr: float = 1e-3) -> Optimizer:
        # optimizer factory handed over to the task: builds the optimizer
        # described by cfg.optimizer, overriding its learning rate with `lr`
        return instantiate(cfg.optimizer, parameters, lr=lr)

    task = instantiate(
        cfg.task,
        protocol,
        optimizer=optimizer,
        learning_rate=cfg.optimizer.lr,
    )

    model = instantiate(cfg.model, task=task)

    # the task tells which validation metric to monitor,
    # and whether it should be minimized or maximized
    monitor, mode = task.validation_monitor

    model_checkpoint = ModelCheckpoint(
        monitor=monitor,
        mode=mode,
        save_top_k=10,
        period=1,
        save_last=True,
        save_weights_only=False,
        dirpath=".",
        # evaluates to the literal template "{epoch}-{<monitor>:.3f}"
        filename=f"{{epoch}}-{{{monitor}:.3f}}",
        verbose=cfg.verbose,
    )

    early_stopping = EarlyStopping(
        monitor=monitor,
        mode=mode,
        min_delta=0.0,
        patience=10,
        strict=True,
        verbose=cfg.verbose,
    )

    logger = TensorBoardLogger(
        ".",
        name="",
        version="",
        log_graph=True,
    )

    trainer = instantiate(
        cfg.trainer,
        callbacks=[model_checkpoint, early_stopping],
        logger=logger,
    )

    # some trainer configurations (e.g. trainer/fast_dev_run.yaml) do not
    # define "auto_lr_find" at all: fall back to False instead of failing
    # on a missing key.
    if cfg.trainer.get("auto_lr_find", False):
        # HACK: these two lines below should be removed once
        # the corresponding bug is fixed in pytorch-lightning.
        # https://github.com/pyannote/pyannote-audio/issues/514
        task.setup(stage="fit")
        model.setup(stage="fit")
        trainer.tune(model, task)

    trainer.fit(model, task)

    # return a "lower is better" objective, whatever the monitoring mode
    best_monitor = float(early_stopping.best_score)
    return best_monitor if mode == "min" else -best_monitor


# run training only when this module is executed as a script
if __name__ == "__main__":
    main()
Empty file.
10 changes: 10 additions & 0 deletions pyannote/audio/cli/train_config/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

protocol: ???
verbose: False

defaults:
- task: vad
- model: debug
- optimizer: adam
- trainer: default
- hydra: train
80 changes: 80 additions & 0 deletions pyannote/audio/cli/train_config/hydra/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# @package _group_

run:
# dir: train/${now:%Y-%m-%d}/${now:%H-%M-%S}
dir: ${protocol}/${task._target_}/${now:%Y-%m-%d}/${now:%H-%M-%S}

sweep:
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}/${protocol}/${task._target_}
subdir: ${hydra.job.num}

output_subdir: ""

help:
app_name: pyannote-audio-train

# Help header, customize to describe your app to your users
header: == ${hydra.help.app_name} ==

footer: |-
Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help
template: |-
${hydra.help.header}
pyannote-audio-train protocol={protocol_name} task={task} model={model}
{task} can be any of the following:
* vad (default) = voice activity detection
* scd = speaker change detection
* osd = overlapped speech detection
* xseg = multi-task segmentation
{model} can be any of the following:
* debug (default) = simple segmentation model for debugging purposes
{optimizer} can be any of the following:
* adam (default) = Adam optimizer
{trainer} can be any of the following:
* fast_dev_run for debugging
* default (default) for training the model
Options
=======
Here, we describe the most common options: use "--cfg job" option to get a complete list.
* task.duration: audio chunk duration (in seconds)
* task.batch_size: number of audio chunks per batch
* task.num_workers: number of workers used for generating training chunks
* optimizer.lr: learning rate
* trainer.auto_lr_find: use pytorch-lightning AutoLR
Hyper-parameter optimization
============================
Because it is powered by Hydra (https://hydra.cc), one can run grid search using the --multirun option.
For instance, the following command will run the same job three times, with three different learning rates:
pyannote-audio-train --multirun protocol={protocol_name} task={task} optimizer.lr=1e-3,1e-2,1e-1
Even better, one can use Ax (https://ax.dev) sweeper to optimize learning rate directly:
pyannote-audio-train --multirun hydra/sweeper=ax protocol={protocol_name} task={task} optimizer.lr="interval(1e-3, 1e-1)"
See https://hydra.cc/docs/plugins/ax_sweeper for more details.
User-defined task or model
==========================
1. define your_package.YourTask (or your_package.YourModel) class
2. create file /path/to/your_config/task/your_task.yaml (or /path/to/your_config/model/your_model.yaml)
# @package _group_
_target_: your_package.YourTask # or YourModel
param1: value1
param2: value2
3. call pyannote-audio-train --config-dir /path/to/your_config task=your_task task.param1=modified_value1 model=your_model ...
${hydra.help.footer}
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/model/PyanNet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.models.PyanNet
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/model/debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.models.debug.SimpleSegmentationModel
7 changes: 7 additions & 0 deletions pyannote/audio/cli/train_config/optimizer/adam.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _group_
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/task/osd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.tasks.OverlappedSpeechDetection
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/task/scd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.tasks.SpeakerChangeDetection
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/task/vad.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.tasks.VoiceActivityDetection
2 changes: 2 additions & 0 deletions pyannote/audio/cli/train_config/task/xseg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
_target_: pyannote.audio.tasks.MultiTaskSegmentation
45 changes: 45 additions & 0 deletions pyannote/audio/cli/train_config/trainer/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# @package _group_
_target_: pytorch_lightning.Trainer
#accelerator: None
accumulate_grad_batches: 1
amp_backend: 'native'
amp_level: 'O2'
auto_lr_find: False
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: True
check_val_every_n_epoch: 1
checkpoint_callback: True
deterministic: False
fast_dev_run: False
flush_logs_every_n_steps: 100
#gpus: None
gradient_clip_val: 0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
#log_gpu_memory: None
max_epochs: 1000
#max_steps: None
min_epochs: 1
#min_steps: None
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
prepare_data_per_node: True
process_position: 0
#profiler: None
progress_bar_refresh_rate: 1
reload_dataloaders_every_epoch: False
replace_sampler_ddp: True
sync_batchnorm: False
terminate_on_nan: False
#tpu_cores: None
track_grad_norm: -1
#truncated_bptt_steps: None
val_check_interval: 1.0
#weights_save_path: None
weights_summary: 'top'
3 changes: 3 additions & 0 deletions pyannote/audio/cli/train_config/trainer/fast_dev_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# @package _group_
_target_: pytorch_lightning.Trainer
fast_dev_run: True
9 changes: 9 additions & 0 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def build(self):
# (e.g. the final classification and activation layers)
pass

#  used by Tensorboard logger to log model graph
@cached_property
def example_input_array(self) -> torch.Tensor:
    """Example input, as provided by the task this model is attached to.

    Simply forwards `self.task.example_input_array`; the value is computed
    on first access and cached afterwards.
    """
    task = self.task
    return task.example_input_array

def helper_introspect(
self,
specifications: TaskSpecification,
Expand Down Expand Up @@ -299,6 +304,10 @@ def setup(self, stage=None):
# so that its dataloader knows how to generate targets
self.task.model_introspection = self.hparams.model_introspection

# this is needed to support pytorch-lightning auto_lr_find feature
# as it expects to find a "learning_rate" entry in model.hparams
self.hparams.learning_rate = self.task.learning_rate

def on_save_checkpoint(self, checkpoint):

#  put everything pyannote.audio-specific under pyannote.audio
Expand Down
Loading

0 comments on commit 06d3b61

Please sign in to comment.