
Commit a8f2457: Lhotse/K2 support
1 parent 6f5f6a4

File tree: 4 files changed, +297 -1 lines changed

espresso/data/__init__.py (+2)
@@ -6,6 +6,7 @@
 from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
 from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset
 from .asr_dataset import AsrDataset
+from .asr_k2_dataset import AsrK2Dataset
 from .asr_dictionary import AsrDictionary
 from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset
 from .feat_text_dataset import (
@@ -20,6 +21,7 @@
     "AsrChainDataset",
     "AsrDataset",
     "AsrDictionary",
+    "AsrK2Dataset",
     "AsrTextDataset",
     "AsrXentDataset",
     "FeatBucketPadLengthDataset",

espresso/data/asr_k2_dataset.py (+258)
@@ -0,0 +1,258 @@
+# Copyright (c) Yiming Wang
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import re
+from typing import Dict, List
+
+import numpy as np
+
+import torch
+
+from fairseq.data import FairseqDataset, data_utils
+
+import espresso.tools.utils as speech_utils
+try:
+    # TODO use pip install once it's available
+    from espresso.tools.lhotse.cut import CutSet
+except ImportError:
+    raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")
+
+
+def collate(samples, pad_to_length=None, pad_to_multiple=1):
+    if len(samples) == 0:
+        return {}
+
+    def merge(key, pad_to_length=None):
+        if key == "source":
+            return speech_utils.collate_frames(
+                [sample[key] for sample in samples], 0.0,
+                pad_to_length=pad_to_length,
+                pad_to_multiple=pad_to_multiple,
+            )
+        else:
+            raise ValueError("Invalid key.")
+
+    id = torch.LongTensor([sample["id"] for sample in samples])
+    src_frames = merge(
+        "source",
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
+    # sort by descending source length
+    if pad_to_length is not None:
+        src_lengths = torch.IntTensor(
+            [sample["source"].ne(0.0).any(dim=1).int().sum() for sample in samples]
+        )
+    else:
+        src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
+    src_lengths, sort_order = src_lengths.sort(descending=True)
+    id = id.index_select(0, sort_order)
+    utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()]
+    src_frames = src_frames.index_select(0, sort_order)
+    ntokens = src_lengths.sum().item()
+
+    target = None
+    if samples[0].get("target", None) is not None and len(samples[0]["target"]) > 0:
+        # reorder the list of samples to make things easier
+        # (no need to reorder every element in target)
+        samples = [samples[i] for i in sort_order.numpy()]
+
+        from torch.utils.data._utils.collate import default_collate
+
+        dataset_idx_to_batch_idx = {
+            sample["target"][0]["sequence_idx"]: batch_idx
+            for batch_idx, sample in enumerate(samples)
+        }
+
+        def update(d: Dict, **kwargs) -> Dict:
+            for key, value in kwargs.items():
+                d[key] = value
+            return d
+
+        target = default_collate([
+            update(sup, sequence_idx=dataset_idx_to_batch_idx[sup["sequence_idx"]])
+            for sample in samples
+            for sup in sample["target"]
+        ])
+
+    batch = {
+        "id": id,
+        "utt_id": utt_id,
+        "nsentences": len(samples),
+        "ntokens": ntokens,
+        "net_input": {
+            "src_tokens": src_frames,
+            "src_lengths": src_lengths,
+        },
+        "target": target,
+    }
+    return batch
+
+
+class AsrK2Dataset(FairseqDataset):
+    """
+    A K2 dataset for ASR.
+
+    Args:
+        cuts (lhotse.CutSet): Lhotse CutSet to wrap
+        shuffle (bool, optional): shuffle dataset elements before batching
+            (default: True).
+        pad_to_multiple (int, optional): pad src lengths to a multiple of this value
+    """
+
+    def __init__(
+        self,
+        cuts: CutSet,
+        shuffle=True,
+        pad_to_multiple=1,
+    ):
+        self.cuts = cuts
+        self.cut_ids = list(self.cuts.ids)
+        self.src_sizes = np.array(
+            [cut.num_frames if cut.has_features else cut.num_samples for cut in cuts]
+        )
+        self.tgt_sizes = None
+        first_cut = cuts[self.cut_ids[0]]
+        # assume all cuts have no supervisions if the first one does not
+        if len(first_cut.supervisions) > 0:
+            assert len(first_cut.supervisions) == 1, "Only single-supervision cuts are allowed"
+            assert first_cut.frame_shift is not None, "features are not available in cuts"
+            self.tgt_sizes = np.array(
+                [
+                    round(
+                        cut.supervisions[0].trim(cut.duration).duration / cut.frame_shift
+                    ) for cut in cuts
+                ]
+            )
+        self.shuffle = shuffle
+        self.epoch = 1
+        self.sizes = (
+            np.vstack((self.src_sizes, self.tgt_sizes)).T
+            if self.tgt_sizes is not None
+            else self.src_sizes
+        )
+        self.pad_to_multiple = pad_to_multiple
+        self.feat_dim = self.cuts[self.cut_ids[0]].num_features
+
+    def __getitem__(self, index):
+        cut_id = self.cut_ids[index]
+        cut = self.cuts[cut_id]
+        features = torch.from_numpy(cut.load_features())
+
+        example = {
+            "id": index,
+            "utt_id": cut_id,
+            "source": features,
+            "target": [
+                {
+                    "sequence_idx": index,
+                    "text": sup.text,
+                    "start_frame": round(sup.start / cut.frame_shift),
+                    "num_frames": round(sup.duration / cut.frame_shift),
+                }
+                # A CutSet's supervisions can exceed the cut, when the cut starts/ends in the
+                # middle of a supervision (they would have relative times, e.g. a -2 seconds
+                # start, meaning it started 2 seconds before the cut starts). We use s.trim()
+                # to get rid of that property, ensuring the supervision time span does not
+                # exceed that of the cut.
+                for sup in (s.trim(cut.duration) for s in cut.supervisions)
+            ]
+        }
+        return example
+
+    def __len__(self):
+        return len(self.cuts)
+
+    def collater(self, samples, pad_to_length=None):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[dict]): samples to collate
+            pad_to_length (dict, optional): a dictionary of
+                {"source": source_pad_to_length}
+                to indicate the max length to pad the source to.
+
+        Returns:
+            dict: a mini-batch with the following keys:
+
+            - `id` (LongTensor): example IDs in the original input order
+            - `utt_id` (List[str]): list of utterance ids
+            - `nsentences` (int): batch size
+            - `ntokens` (int): total number of tokens in the batch
+            - `net_input` (dict): the input to the Model, containing keys:
+
+              - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
+                the source of shape `(bsz, src_len, feat_dim)`
+              - `src_lengths` (IntTensor): 1D Tensor of the unpadded
+                lengths of each source sequence of shape `(bsz)`
+
+            - `target` (List[Dict[str, Any]]): a list representing a batch of
+              supervisions
+        """
+        return collate(
+            samples, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple,
+        )
+
+    def num_tokens(self, index):
+        """Return the number of frames in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.src_sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (
+            self.src_sizes[index],
+            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+        )
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self)).astype(np.int64)
+        else:
+            indices = np.arange(len(self), dtype=np.int64)
+        # sort by target length, then source length
+        if self.tgt_sizes is not None:
+            indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
+        return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    def filter_indices_by_size(self, indices, max_sizes):
+        """Filter a list of sample indices. Remove those that are longer
+        than specified in *max_sizes*.
+
+        Args:
+            indices (np.array): original array of sample indices
+            max_sizes (int or list[int] or tuple[int]): max sample size,
+                can be defined separately for src and tgt (then list or tuple)
+
+        Returns:
+            np.array: filtered sample array
+            list: list of removed indices
+        """
+        return data_utils.filter_paired_dataset_indices_by_size(
+            self.src_sizes,
+            self.tgt_sizes,
+            indices,
+            max_sizes,
+        )
+
+    @property
+    def supports_fetch_outside_dataloader(self):
+        """Whether this dataset supports fetching outside the workers of the dataloader."""
+        return False
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return False  # to avoid running out of CPU RAM
+
+    def set_epoch(self, epoch):
+        super().set_epoch(epoch)
+        self.epoch = epoch
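A minimal usage sketch of the class above (the manifest path and batch size are hypothetical; it assumes Lhotse has been installed via `make lhotse` and that the cuts carry precomputed features):

    from espresso.data import AsrK2Dataset
    from espresso.tools.lhotse.cut import CutSet

    cuts = CutSet.from_json("data/cuts_train.json")  # hypothetical manifest path
    dataset = AsrK2Dataset(cuts, shuffle=True, pad_to_multiple=8)

    # length-sorted indices, as the batch iterator would request them
    indices = dataset.ordered_indices()
    samples = [dataset[i] for i in indices[:4]]      # fetch a small batch
    batch = dataset.collater(samples)
    print(batch["net_input"]["src_tokens"].shape)    # (bsz, src_len, feat_dim)
    print(batch["net_input"]["src_lengths"])         # unpadded lengths, descending

Note that `collater()` re-sorts the samples by descending source length and remaps each supervision's `sequence_idx` from its dataset index to its position within the batch.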

espresso/tasks/speech_recognition_hybrid.py (+30)
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
 from espresso.data import (
     AliScpCachedDataset,
     AsrChainDataset,
+    AsrK2Dataset,
     AsrXentDataset,
     AsrDictionary,
     AsrTextDataset,
@@ -74,6 +75,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
         },
     )
     feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"})
+    use_k2_dataset: bool = field(default=False, metadata={"help": "if True use K2 dataset"})
     specaugment_config: Optional[str] = field(
         default=None,
         metadata={
@@ -146,6 +148,22 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
     max_epoch: int = II("optimization.max_epoch")  # to determine whether in training stage


+def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1):
+    try:
+        # TODO use pip install once it's available
+        from espresso.tools.lhotse.cut import CutSet
+    except ImportError:
+        raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")
+
+    data_json_path = os.path.join(data_path, "cuts_{}.json".format(split))
+    if not os.path.isfile(data_json_path):
+        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
+
+    cut_set = CutSet.from_json(data_json_path)
+    logger.info("{} {} examples".format(data_json_path, len(cut_set)))
+    return AsrK2Dataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple)
+
+
 def get_asr_dataset_from_json(
     data_path,
     split,
@@ -343,6 +361,7 @@ def __init__(self, cfg: DictConfig, dictionary):
         super().__init__(cfg)
         self.dictionary = dictionary
         self.feat_in_channels = cfg.feat_in_channels
+        self.use_k2_dataset = cfg.use_k2_dataset
         self.specaugment_config = cfg.specaugment_config
         self.num_targets = cfg.num_targets
         self.training_stage = (cfg.max_epoch > 0)  # a hack
@@ -402,6 +421,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             paths = paths[:1]
         data_path = paths[(epoch - 1) % len(paths)]

+        if self.use_k2_dataset:
+            self.datasets[split] = get_k2_dataset_from_json(
+                data_path,
+                split,
+                shuffle=(split != self.cfg.gen_subset),
+                pad_to_multiple=self.cfg.required_seq_len_multiple,
+                seed=self.cfg.seed,
+            )
+            self.feat_dim = self.datasets[split].feat_dim
+            return
+
         self.datasets[split] = get_asr_dataset_from_json(
             data_path,
             split,
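For illustration, the new branch in `load_dataset()` is equivalent to calling the helper directly; the loader expects a Lhotse cut manifest named `cuts_<split>.json` under the data directory (the `data_path` below is hypothetical):

    # mirrors load_dataset() when use_k2_dataset=True
    dataset = get_k2_dataset_from_json(
        "data", "train", shuffle=True, pad_to_multiple=8, seed=1,
    )
    feat_dim = dataset.feat_dim  # propagated to the task, as in the diff above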

espresso/tools/Makefile (+7, -1)
@@ -1,5 +1,5 @@
 KALDI =
-PYTHON_DIR = ~/anaconda3/bin
+PYTHON_DIR = /export/b03/ywang/anaconda3/bin

 CXX ?= g++

@@ -30,6 +30,7 @@ kaldi:
 endif

 clean: openfst_cleaned
+	rm -rf lhotse
 	rm -rf pychain
 	rm -rf kaldi

@@ -79,3 +80,8 @@ pychain:
 	export PATH=$(PYTHON_DIR):$$PATH && \
 	cd pychain/openfst_binding && python3 setup.py install && \
 	cd ../pytorch_binding && python3 setup.py install
+
+.PHONY: lhotse
+lhotse:
+	test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git
+	export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e .
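After running `make lhotse` from espresso/tools (the same step the ImportError messages above point to), a quick sanity check that the editable install worked:

    # should import without raising (assumes the clone lives in espresso/tools/lhotse)
    from espresso.tools.lhotse.cut import CutSet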
