-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
executable file
·245 lines (200 loc) · 9.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#! /usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
import math
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Arguments
    ---------
    output : torch.Tensor
        2-D scores of shape ``(batch, num_classes)``; the k largest entries
        of each row are taken as that sample's top-k predictions.
    target : torch.Tensor
        Ground-truth class index for each row of ``output``.
    topk : tuple of int
        The values of k to evaluate.

    Returns
    -------
    list of torch.Tensor
        One 1-element tensor per k (in the order of ``topk``) holding the
        percentage of samples whose target is among their k best classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) -- row j holds every sample's j-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of
        # `pred`, so `.view(-1)` can fail here; `.reshape` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class PreEmphasis(torch.nn.Module):
    """Pre-emphasis filter ``y[t] = x[t] - coef * x[t - 1]``.

    Operates on 2-D input along its last axis. The first output sample uses
    reflect padding (``x[-1]`` is taken to be ``x[1]``), so the output has
    the same shape as the input.

    Arguments
    ---------
    coef : float
        Pre-emphasis coefficient.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # conv1d computes cross-correlation, so the FIR taps [1, -coef] are
        # stored in reversed order. Shape: (out_channels=1, in_channels=1, 2).
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).view(1, 1, 2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis to a 2-D tensor and return a tensor of the same shape."""
        assert len(input.size()) == 2, 'The number of dimensions of input tensor must be 2!'
        # Add a channel axis for conv1d, then reflect-pad one sample on the
        # left so the output keeps the input length.
        x = input.unsqueeze(1)
        x = F.pad(x, (1, 0), 'reflect')
        # The buffer is float32; match the input dtype so e.g. float64
        # input does not fail with a dtype mismatch inside conv1d.
        kernel = self.flipped_filter.to(dtype=x.dtype)
        return F.conv1d(x, kernel).squeeze(1)
class Resample(torch.nn.Module):
    """Resample waveforms from ``orig_freq`` to ``new_freq`` (in Hz).

    Uses a polyphase windowed-sinc low-pass filter. Both rates are reduced
    by their GCD into "units" of ``conv_stride`` input samples and
    ``output_samples`` output samples. For each output phase, the input is
    convolved with that phase's taps (strided by ``conv_stride``). The
    result is then interleaved into the output with a strided identity
    transposed convolution. Filter taps are built lazily on the first call
    that actually resamples and are cached on the instance as
    ``first_indices`` and ``weights``.

    Arguments
    ---------
    orig_freq : int
        Sampling rate of the input.
    new_freq : int
        Sampling rate of the output.
    lowpass_filter_width : int
        Half-width of the filter window, measured in sinc zero-crossing
        intervals ``1 / (2 * cutoff)`` on each side of an output sample.
    """

    def __init__(
        self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
    ):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.lowpass_filter_width = lowpass_filter_width

        # Compute rate for striding
        self._compute_strides()
        assert self.orig_freq % self.conv_stride == 0
        assert self.new_freq % self.conv_transpose_stride == 0

    def _compute_strides(self):
        """Derive the polyphase strides from the GCD-reduced rate ratio."""
        # Express both rates in units of their GCD: every
        # `input_samples_in_unit` input samples produce exactly
        # `output_samples` output samples.
        base_freq = math.gcd(self.orig_freq, self.new_freq)
        input_samples_in_unit = self.orig_freq // base_freq
        self.output_samples = self.new_freq // base_freq

        # Store the appropriate stride based on the new units
        self.conv_stride = input_samples_in_unit
        self.conv_transpose_stride = self.output_samples

    def forward(self, waveforms):
        """Resample ``waveforms`` along the time axis.

        Arguments
        ---------
        waveforms : torch.Tensor
            Shape ``(batch, time)`` or ``(batch, time, channels)``.

        Returns
        -------
        torch.Tensor
            Same layout and dtype as the input with the time axis resampled.
            When the two rates are equal the input is returned unchanged.

        Raises
        ------
        ValueError
            If the rates differ and ``waveforms`` is neither 2-D nor 3-D.
        """
        # Don't do anything if the frequencies are the same (and don't
        # bother building a filter that would never be used).
        if self.orig_freq == self.new_freq:
            return waveforms

        if not hasattr(self, "first_indices"):
            self._indices_and_weights(waveforms)

        # Bring the input to (batch, channels, time) for conv1d.
        unsqueezed = False
        if len(waveforms.shape) == 2:
            waveforms = waveforms.unsqueeze(1)
            unsqueezed = True
        elif len(waveforms.shape) == 3:
            waveforms = waveforms.transpose(1, 2)
        else:
            raise ValueError("Input must be 2 or 3 dimensions")

        # Do resampling
        resampled_waveform = self._perform_resample(waveforms)

        # Restore the caller's layout.
        if unsqueezed:
            resampled_waveform = resampled_waveform.squeeze(1)
        else:
            resampled_waveform = resampled_waveform.transpose(1, 2)

        return resampled_waveform

    def _perform_resample(self, waveforms):
        """Apply the polyphase filter to a ``(batch, channels, time)`` tensor."""
        # Compute output size and initialize
        batch_size, num_channels, wave_len = waveforms.size()
        window_size = self.weights.size(1)
        tot_output_samp = self._output_samples(wave_len)

        # Keep the cached taps on the input's device. Use a local copy in
        # the input's dtype: the taps are built as float32, and conv1d
        # rejects mixed dtypes (e.g. float64 waveforms).
        self.weights = self.weights.to(waveforms.device)
        weights = self.weights.to(dtype=waveforms.dtype)

        resampled_waveform = torch.zeros(
            (batch_size, num_channels, tot_output_samp),
            device=waveforms.device,
            dtype=waveforms.dtype,
        )

        # eye size: (num_channels, num_channels, 1) -- an identity kernel,
        # so conv_transpose1d only spreads samples apart by its stride.
        eye = torch.eye(
            num_channels, device=waveforms.device, dtype=waveforms.dtype
        ).unsqueeze(2)

        # The padded length needed for the last (possibly partial) window
        # is the same for every phase.
        max_index = (tot_output_samp - 1) // self.output_samples
        end_index = max_index * self.conv_stride + window_size

        # Iterate over the phases in the polyphase filter
        for i in range(self.first_indices.size(0)):
            wave_to_conv = waveforms
            first_index = int(self.first_indices[i].item())
            if first_index >= 0:
                # trim the signal as the filter will not be applied
                # before the first_index
                wave_to_conv = wave_to_conv[..., first_index:]

            # Pad on the right to allow partial windows at the end of the
            # signal; pad on the left when this phase's window starts
            # before sample 0 (negative first_index).
            current_wave_len = wave_len - first_index
            right_padding = max(0, end_index + 1 - current_wave_len)
            left_padding = max(0, -first_index)
            wave_to_conv = torch.nn.functional.pad(
                wave_to_conv, (left_padding, right_padding)
            )
            conv_wave = torch.nn.functional.conv1d(
                input=wave_to_conv,
                weight=weights[i].repeat(num_channels, 1, 1),
                stride=self.conv_stride,
                groups=num_channels,
            )

            # we want conv_wave[:, i] to be at
            # output[:, i + n*conv_transpose_stride]
            dilated_conv_wave = torch.nn.functional.conv_transpose1d(
                conv_wave, eye, stride=self.conv_transpose_stride
            )

            # Shift by the phase offset i and pad/crop to the output length.
            left_padding = i
            previous_padding = left_padding + dilated_conv_wave.size(-1)
            right_padding = max(0, tot_output_samp - previous_padding)
            dilated_conv_wave = torch.nn.functional.pad(
                dilated_conv_wave, (left_padding, right_padding)
            )
            dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]

            resampled_waveform += dilated_conv_wave

        return resampled_waveform

    def _output_samples(self, input_num_samp):
        """Return the number of output samples produced from ``input_num_samp`` inputs.

        Counts the output sample times that fall in the half-open interval
        ``[0, input_num_samp / orig_freq)``; returns 0 for empty input.
        """
        # For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
        # where tick_freq is the least common multiple of samp_in and
        # samp_out.
        samp_in = int(self.orig_freq)
        samp_out = int(self.new_freq)

        tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
        ticks_per_input_period = tick_freq // samp_in

        # work out the number of ticks in the time interval
        # [ 0, input_num_samp/samp_in ).
        interval_length = input_num_samp * ticks_per_input_period
        if interval_length <= 0:
            return 0
        ticks_per_output_period = tick_freq // samp_out

        # Get the last output-sample in the closed interval,
        # i.e. replacing [ ) with [ ]. Note: integer division rounds down.
        # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
        # explanation of the notation.
        last_output_samp = interval_length // ticks_per_output_period

        # We need the last output-sample in the open interval, so if it
        # takes us to the end of the interval exactly, subtract one.
        if last_output_samp * ticks_per_output_period == interval_length:
            last_output_samp -= 1

        # First output-sample index is zero, so the number of output samples
        # is the last output-sample plus one.
        num_output_samp = last_output_samp + 1

        return num_output_samp

    def _indices_and_weights(self, waveforms):
        """Build and cache the per-phase filter taps.

        ``waveforms`` is only used to choose the device. Sets:

        * ``first_indices`` -- shape ``(output_samples,)``: the first input
          index (relative to the start of a unit, possibly negative) that
          each output phase reads.
        * ``weights`` -- shape ``(output_samples, max_weight_width)``: the
          Hann-windowed sinc taps for each phase.
        """
        # Lowpass filter frequency depends on smaller of two frequencies;
        # the 0.99 factor keeps the cutoff just below that Nyquist limit.
        min_freq = min(self.orig_freq, self.new_freq)
        lowpass_cutoff = 0.99 * 0.5 * min_freq

        assert lowpass_cutoff * 2 <= min_freq

        # Half-width of the filter window, in seconds.
        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)

        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2

        # Time (in seconds) of each output phase within one unit.
        output_t = torch.arange(
            start=0.0, end=self.output_samples, device=waveforms.device,
        )
        output_t /= self.new_freq
        min_t = output_t - window_width
        max_t = output_t + window_width

        # Range of input sample indices whose times fall inside each window.
        min_input_index = torch.ceil(min_t * self.orig_freq)
        max_input_index = torch.floor(max_t * self.orig_freq)
        num_indices = max_input_index - min_input_index + 1

        # Every phase gets the same number of taps (the maximum); taps that
        # fall outside a phase's window keep weight 0 below.
        max_weight_width = num_indices.max()
        j = torch.arange(max_weight_width, device=waveforms.device)
        input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
        delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)

        weights = torch.zeros_like(delta_t)
        inside_window_indices = delta_t.abs().lt(window_width)

        # raised-cosine (Hanning) window with width `window_width`
        weights[inside_window_indices] = 0.5 * (
            1
            + torch.cos(
                2
                * math.pi
                * lowpass_cutoff
                / self.lowpass_filter_width
                * delta_t[inside_window_indices]
            )
        )

        t_eq_zero_indices = delta_t.eq(0.0)
        t_not_eq_zero_indices = ~t_eq_zero_indices

        # sinc filter function
        weights[t_not_eq_zero_indices] *= torch.sin(
            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
        ) / (math.pi * delta_t[t_not_eq_zero_indices])

        # limit of the function at t = 0
        weights[t_eq_zero_indices] *= 2 * lowpass_cutoff

        # Scale by the input sample period to turn the continuous-time
        # impulse response into discrete taps.
        # size (output_samples, max_weight_width)
        weights /= self.orig_freq

        self.first_indices = min_input_index
        self.weights = weights