Skip to content

Commit fa07059

Browse files
committed
add random split of negatives
1 parent 7dd7fee commit fa07059

File tree

1 file changed

+114
-19
lines changed

1 file changed

+114
-19
lines changed

examples/mobvoihotwords/local/data_prep.py

+114-19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
import os
1010
import sys
11+
from typing import List
1112
from concurrent.futures import ProcessPoolExecutor
1213
from pathlib import Path
1314

@@ -31,7 +32,15 @@ def get_parser():
3132
parser.add_argument("--data-dir", default="data", type=str, help="data directory")
3233
parser.add_argument("--seed", default=1, type=int, help="random seed")
3334
parser.add_argument(
34-
"--nj", default=1, type=int, help="number of jobs for features extraction"
35+
"--num-jobs", default=1, type=int, help="number of jobs for features extraction"
36+
)
37+
parser.add_argument(
38+
"--max-remaining-duration", default=0.3, type=float,
39+
help="not split if the left-over duration is less than this many seconds"
40+
)
41+
parser.add_argument(
42+
"--overlap-duration", default=0.3, type=float,
43+
help="overlap between adjacent segments while splitting negative recordings"
3544
)
3645
# fmt: on
3746

@@ -41,7 +50,9 @@ def get_parser():
4150
def main(args):
4251
try:
4352
# TODO use pip install once it's available
44-
from espresso.tools.lhotse import CutSet, Mfcc, MfccConfig, LilcomFilesWriter, WavAugmenter
53+
from espresso.tools.lhotse import (
54+
CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter
55+
)
4556
from espresso.tools.lhotse.manipulation import combine
4657
from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
4758
except ImportError:
@@ -68,36 +79,46 @@ def main(args):
6879
np.random.seed(args.seed)
6980
# equivalent to Kaldi's mfcc_hires config
7081
mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400))
71-
num_jobs = args.nj
7282
for partition, manifests in mobvoihotwords_manifests.items():
7383
cut_set = CutSet.from_manifests(
7484
recordings=manifests["recordings"],
7585
supervisions=manifests["supervisions"],
7686
)
7787
sampling_rate = next(iter(cut_set)).sampling_rate
78-
with ProcessPoolExecutor(num_jobs) as ex:
88+
with ProcessPoolExecutor(args.num_jobs) as ex:
7989
if "train" in partition:
80-
# original set
81-
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_orig") as storage:
82-
cut_set_orig = cut_set.compute_and_store_features(
90+
# split negative recordings into smaller chunks with lengths sampled from
91+
# length distribution of positive recordings
92+
pos_durs = get_positive_durations(manifests["supervisions"])
93+
with numpy_seed(args.seed):
94+
cut_set = keep_positives_and_split_negatives(
95+
cut_set,
96+
pos_durs,
97+
max_remaining_duration=args.max_remaining_duration,
98+
overlap_duration=args.overlap_duration,
99+
)
100+
# "clean" set
101+
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage:
102+
cut_set_clean = cut_set.compute_and_store_features(
83103
extractor=mfcc,
84104
storage=storage,
85105
augmenter=None,
86106
executor=ex,
87107
)
88-
# augmented with reverbration
89-
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
90-
cut_set_rev = cut_set.compute_and_store_features(
91-
extractor=mfcc,
92-
storage=storage,
93-
augmenter=WavAugmenter(effect_chain=reverb()),
94-
excutor=ex,
95-
)
108+
# augmented with reverberation
109+
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
110+
with numpy_seed(args.seed):
111+
cut_set_rev = cut_set.compute_and_store_features(
112+
extractor=mfcc,
113+
storage=storage,
114+
augmenter=WavAugmenter(effect_chain=reverb()),
115+
excutor=ex,
116+
)
96117
cut_set_rev = CutSet.from_cuts(
97118
cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts
98119
)
99120
# augmented with speed perturbation
100-
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
121+
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
101122
cut_set_sp1p1 = cut_set.compute_and_store_features(
102123
extractor=mfcc,
103124
storage=storage,
@@ -109,7 +130,7 @@ def main(args):
109130
cut_set_sp1p1 = CutSet.from_cuts(
110131
cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts
111132
)
112-
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
133+
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
113134
cut_set_sp0p9 = cut_set.compute_and_store_features(
114135
extractor=mfcc,
115136
storage=storage,
@@ -121,9 +142,9 @@ def main(args):
121142
cut_set_sp0p9 = CutSet.from_cuts(
122143
cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts
123144
)
124-
# combine the original and augmented sets together
145+
# combine the clean and augmented sets together
125146
cut_set = combine(
126-
cut_set_orig, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
147+
cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
127148
)
128149
else: # no augmentations for dev and test sets
129150
with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage:
@@ -137,6 +158,80 @@ def main(args):
137158
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
138159

139160

161+
def get_positive_durations(sup_set: "SupervisionSet") -> List[float]:
    """
    Get duration values of all positive recordings.

    Assumes SupervisionSegment.text is "FREETEXT" for all negative recordings,
    and that SupervisionSegment.duration equals the corresponding
    Recording.duration.

    Args:
        sup_set (SupervisionSet): supervisions to scan. Only iteration and the
            `text` / `duration` attributes of each segment are used.

    Returns:
        list[float]: durations (as stored on the segments) of every segment
            whose text is not "FREETEXT", in iteration order
    """
    # The annotation is a string on purpose: it must not be evaluated when the
    # function is defined, because the lhotse names may not be bound at module
    # level at that point.
    return [sup.duration for sup in sup_set if sup.text != "FREETEXT"]
168+
169+
170+
def _chunk_id(cut_id: str, start: float, end: float) -> str:
    """Build the ID of a chunk from its parent ID and its relative start/end
    times (seconds), encoded as zero-padded hundredths of a second."""
    return "{id}-{s:07d}-{e:07d}".format(
        id=cut_id,
        s=int(round(100 * start)),
        e=int(round(100 * end)),
    )


def keep_positives_and_split_negatives(
    cut_set: "CutSet",
    durations: List[float],
    max_remaining_duration: float = 0.3,
    overlap_duration: float = 0.3,
) -> "CutSet":
    """
    Returns a new CutSet where all the positives are directly taken from the original
    input cut set, and the negatives are obtained by splitting original negatives
    into shorter chunks of random lengths drawn from the given length distribution
    (here it is the empirical distribution of the positive recordings). There can
    be overlap between chunks.

    A cut is treated as negative when the text of its single supervision is
    "FREETEXT". Chunk lengths are drawn with `np.random`, so seed numpy before
    calling for reproducible splits.

    Args:
        cut_set (CutSet): original input cut set
        durations (list[float]): list of durations to sample from; every value
            must be positive
        max_remaining_duration (float, optional): not split if the left-over
            duration is less than this many seconds (default: 0.3).
        overlap_duration (float, optional): overlap between adjacent segments.
            If a sampled chunk is not longer than the overlap, no overlap is
            applied after that chunk so that splitting always moves forward
            (default: 0.3)

    Returns:
        CutSet: a new cut set after split

    Raises:
        ValueError: if `durations` contains a non-positive value, or is empty
            while a negative cut has to be split
    """
    assert max_remaining_duration >= 0.0 and overlap_duration >= 0.0
    if len(durations) > 0 and min(durations) <= 0.0:
        # a zero/negative chunk length would never consume the recording
        raise ValueError("all values in durations must be positive")

    def draw_duration() -> float:
        if len(durations) == 0:
            raise ValueError(
                "cannot split negative recordings: durations is empty"
            )
        return durations[np.random.randint(len(durations))]

    new_cuts = []
    for cut in cut_set:
        assert len(cut.supervisions) == 1
        if cut.supervisions[0].text != "FREETEXT":  # keep the positive as it is
            new_cuts.append(cut)
            continue

        offset = 0.0  # start of the next chunk, relative to the start of `cut`
        remaining_duration = cut.duration
        this_dur = draw_duration()
        while remaining_duration > this_dur + max_remaining_duration:
            chunk = cut.truncate(offset=offset, duration=this_dur, preserve_id=True)
            new_cuts.append(
                chunk.with_id(_chunk_id(chunk.id, offset, offset + this_dur))
            )
            step = this_dur - overlap_duration
            if step <= 0.0:
                # The chunk is not longer than the overlap: stepping back (or
                # not at all) would loop forever, so skip the overlap here.
                step = this_dur
            offset += step
            remaining_duration -= step
            this_dur = draw_duration()

        # whatever is left becomes the last chunk of this recording
        tail = cut.truncate(offset=offset, preserve_id=True)
        new_cuts.append(tail.with_id(_chunk_id(tail.id, offset, cut.duration)))

    # Build the result through the class of the input set, so this function
    # does not depend on `CutSet` being bound at module level (in the visible
    # part of this file it is only imported inside main()).
    return type(cut_set).from_cuts(new_cuts)
233+
234+
140235
def reverb(*args, **kwargs):
141236
"""
142237
Returns a reverb effect for wav augmentation.

0 commit comments

Comments
 (0)