Skip to content

Commit

Permalink
Merge pull request #13 from descriptinc/ps/add_weight_to_losses
Browse files Browse the repository at this point in the history
Adds a weight parameter to losses.
  • Loading branch information
pseeth authored Sep 15, 2021
2 parents 57f0fb6 + 7e543df commit 34b0df8
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 7 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.1.2"
__version__ = "0.1.3"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
14 changes: 14 additions & 0 deletions audiotools/metrics/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@
from .. import AudioSignal


class L1Loss(nn.L1Loss):
def __init__(self, weight: float = 1.0, **kwargs):
self.weight = weight
super().__init__(**kwargs)

def forward(self, x, y):
if isinstance(x, AudioSignal):
x = x.audio_data
y = y.audio_data
return super().forward(x, y)


class SISDRLoss(nn.Module):
"""
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
Expand Down Expand Up @@ -31,11 +43,13 @@ def __init__(
reduction: str = " mean",
zero_mean: int = True,
clip_min: int = None,
weight: float = 1.0,
):
self.scaling = scaling
self.reduction = reduction
self.zero_mean = zero_mean
self.clip_min = clip_min
self.weight = weight
super().__init__()

def forward(self, x: AudioSignal, y: AudioSignal):
Expand Down
33 changes: 33 additions & 0 deletions audiotools/metrics/spectral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

import numpy as np
from torch import nn

from .. import AudioSignal
Expand All @@ -19,6 +20,7 @@ def __init__(
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
weight: float = 1.0,
):
super().__init__()
self.stft_params = [
Expand All @@ -28,6 +30,7 @@ def __init__(
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight

def forward(self, x: AudioSignal, y: AudioSignal):
loss = 0.0
Expand All @@ -53,6 +56,7 @@ def __init__(
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
weight: float = 1.0,
):
super().__init__()
self.stft_params = [
Expand All @@ -63,6 +67,7 @@ def __init__(
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight

def forward(self, x: AudioSignal, y: AudioSignal):
loss = 0.0
Expand All @@ -81,3 +86,31 @@ def forward(self, x: AudioSignal, y: AudioSignal):
)
loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
return loss


class PhaseLoss(nn.Module):
def __init__(
self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
):
super().__init__()

self.weight = weight
self.stft_params = STFTParams(window_length, hop_length)

def forward(self, x: AudioSignal, y: AudioSignal):
s = self.stft_params
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)

# Take circular difference
diff = x.phase - y.phase
diff[diff < -np.pi] += 2 * np.pi
diff[diff > np.pi] -= -2 * np.pi

# Scale true magnitude to weights in [0, 1]
x_min, x_max = x.magnitude.min(), x.magnitude.max()
weights = (x.magnitude - x_min) / (x_max - x_min)

# Take weighted mean of all phase errors
loss = ((weights * diff) ** 2).mean()
return loss
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.1.2",
version="0.1.3",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down
21 changes: 21 additions & 0 deletions tests/metrics/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,24 @@ def test_sisdr(scaling, reduction, clip_min, zero_mean):

loss_val_diff = loss(x, y)
assert loss_val_diff > loss_val_identity


def test_l1_loss():
audio_path = "tests/audio/spk/f10_script4_produced.wav"

x = AudioSignal.excerpt(audio_path, duration=1)
y = x.deepcopy()

loss = metrics.distance.L1Loss()

loss_val_identity = loss(x, y)
assert np.allclose(loss_val_identity, 0.0)

# Pass as tensors rather than audio signals
loss_val_identity = loss(x.audio_data, y.audio_data)
assert np.allclose(loss_val_identity, 0.0)

y = AudioSignal.excerpt(audio_path, duration=1)

loss_val_diff = loss(x, y)
assert loss_val_diff > loss_val_identity
22 changes: 17 additions & 5 deletions tests/metrics/test_spectral.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import numpy as np
import pytest
import torch
import torchaudio

import audiotools
from audiotools import AudioSignal
from audiotools import metrics
from audiotools.core import audio_signal


def test_multiscale_stft():
Expand Down Expand Up @@ -53,3 +48,20 @@ def test_mel_spectrogram_loss():

loss_val_diff = loss(x, y)
assert loss_val_diff > loss_val_identity


def test_phase_loss():
audio_path = "tests/audio/spk/f10_script4_produced.wav"

x = AudioSignal.excerpt(audio_path, duration=1)
y = x.deepcopy()

loss = metrics.spectral.PhaseLoss()

loss_val_identity = loss(x, y)
assert np.allclose(loss_val_identity, 0)

y = AudioSignal.excerpt(audio_path, duration=1)

loss_val_diff = loss(x, y)
assert loss_val_diff > loss_val_identity

0 comments on commit 34b0df8

Please sign in to comment.