-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a8c8a63
commit 5154043
Showing
12 changed files
with
276 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,4 +91,5 @@ target/ | |
|
||
# Special | ||
sync.sh | ||
|
||
casync.sh | ||
gsync.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/usr/bin/env python | ||
import librosa | ||
import numpy as np | ||
|
||
|
||
class SpeechFeatures(object):
    """Loads audio files and exposes raw / spectrogram feature accessors.

    Args:
        cwlen: context-window length (presumably in ms — TODO confirm
            against callers; SOURCE does not show the unit).
        cwshift: context-window shift (same unit caveat as ``cwlen``).
        sf: target sampling rate passed to ``librosa.load`` so every
            file is resampled consistently.
    """

    def __init__(self, cwlen, cwshift, sf):
        # BUG FIX: cwlen and cwshift were accepted but silently
        # discarded; store them so windowing code can actually use them.
        self.cwlen = cwlen
        self.cwshift = cwshift
        self.sf = sf

    def _preprocess(self):
        # Placeholder: signal preprocessing not yet implemented.
        pass

    def _load(self, afile):
        # librosa.load returns (signal, sample_rate); the rate is
        # discarded because we force resampling to self.sf.
        signal, _ = librosa.load(afile, sr=self.sf)
        return signal

    def load_raw(self, preprocess=True):
        # TODO: return the raw waveform (stub in SOURCE).
        pass

    def load_sgram(self, preprocess=False):
        # TODO: return the spectrogram (stub in SOURCE).
        pass
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#!/usr/bin/env python | ||
from torch import nn | ||
import torch | ||
from src.models.resnet_base import resnet50 | ||
from src.models.ops import CELoss | ||
from src.utils.math_utils import nextpow2 | ||
from src.utils import torch_utils | ||
|
||
|
||
class SpecNet(nn.Module):
    """ResNet-50 classifier over per-utterance standardized power
    spectrograms computed on the fly from raw waveforms.

    Args:
        num_classes: number of output classes for the ResNet head.
        sf: sampling frequency of the input signals, in Hz.
        win_size: STFT window size in milliseconds.
        hop_len: STFT hop length in milliseconds.
        window: window-generating callable (defaults to
            ``torch.hamming_window``).
    """

    def __init__(self, num_classes, sf, win_size, hop_len,
                 window=torch.hamming_window):
        super().__init__()
        self.num_classes = num_classes
        self.base = resnet50(num_classes=self.num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.loss_obj = CELoss
        self.sf = sf
        # win_size / hop_len arrive in milliseconds; convert to samples.
        self.win_length = round(1e-3 * win_size * self.sf)
        self.hop_length = round(1e-3 * hop_len * self.sf)
        # FFT size: next power of two >= window length in samples.
        self.n_fft = 2 ** nextpow2(self.win_length)
        self.hop_len = hop_len

        self.window = window(self.win_length, device=torch_utils.device)

    def spectrogram(self, signal: torch.Tensor):
        """Return the standardized power spectrogram of ``signal``.

        The spectrogram is normalized per (batch, frequency) row to zero
        mean and unit std over time, then cast to float32.

        Raises:
            RuntimeError: if the frequency axis is not 257 bins
                (debug guard — assumes n_fft == 512).
        """
        window = self.window
        # BUG FIX: modern torch requires return_complex for real input.
        # |STFT|^2 of the complex output equals the old
        # spec.pow(2).sum(-1) over the (real, imag) trailing dim.
        spec = torch.stft(signal, self.n_fft, hop_length=self.hop_length,
                          win_length=self.win_length, window=window,
                          return_complex=True)
        mag_spec = spec.abs().pow(2)  # power spectrogram
        if mag_spec.size(1) != 257:  # Debug
            # BUG FIX: message previously reported size(2) although the
            # guard checks size(1) — it would print the wrong value.
            raise RuntimeError(
                f'Expected SPEC size 257, got {mag_spec.size(1)}')
        # Standardize each frequency bin across time frames.
        spec_mean = mag_spec.mean(2, keepdim=True)
        spec_std = mag_spec.std(2, keepdim=True)
        mag_spec -= spec_mean
        mag_spec /= spec_std
        return mag_spec.to(torch.float)

    def forward(self, batch):
        # batch['raw'] holds raw waveforms; add a channel dim so the
        # spectrogram looks like a 1-channel image to the ResNet.
        signal = batch['raw']
        spec = self.spectrogram(signal).unsqueeze(1)
        return self.base(spec)

    def loss(self, model_outs, batch):
        # Binary task targets gender ids ('gid'); otherwise class ids
        # ('cid') — presumably speaker/class labels, TODO confirm.
        if self.num_classes == 2:
            target = batch['gid']
        else:
            target = batch['cid']
        loss = self.criterion(model_outs, target)
        metric = CELoss(loss, 1)
        return metric, loss
Oops, something went wrong.