RASR FSA Builder #59

Merged: 10 commits, merged on Sep 2, 2024
Changes from 2 commits
137 changes: 137 additions & 0 deletions i6_models/parts/fsa.py
@@ -0,0 +1,137 @@
__all__ = ["TorchFsaBuilder", "WeightedFsa"]

from functools import reduce
from typing import Iterable, NamedTuple, Tuple, TypeVar

import numpy as np
import torch

TWeightedFsa = TypeVar("TWeightedFsa", bound="WeightedFsa")
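
# TWeightedFsa is a "self type" workaround: it lets the WeightedFsa methods below
# be annotated to return the instance's own type (on newer Python versions,
# typing.Self would serve the same purpose).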


class WeightedFsa(NamedTuple):
"""
Convenience class that represents an FSA. It supports scaling the weights of the
fsa by simple left-multiplication and moving the tensors to a different device.
It can simply be passed to `i6_native_ops.fbw.fbw_loss` and `i6_native_ops.fast_viterbi.align_viterbi`.
    :param num_states: the total number of states S, summed over all sequences in the batch
Contributor review comment: Nitpick: E (I guess number of edges) is undefined here.

    :param edges: a [4, E] tensor of edges, where E is the number of edges; each
        column is an edge consisting of from-state, to-state, emission idx and
        the index of the sequence it belongs to
    :param weights: a [E,] tensor of weights for each edge, scaled by the tdp_scale
    :param start_end_states: a [2, N] tensor of start and end states for each of the N sequences
    """

    num_states: torch.IntTensor
    edges: torch.IntTensor
    weights: torch.FloatTensor
    start_end_states: torch.IntTensor

    def __mul__(self: TWeightedFsa, scale: float) -> TWeightedFsa:
"""Multiply the weights, i.e. the third element, with a scale."""
return WeightedFsa._make(tensor * scale if i == 2 else tensor for i, tensor in enumerate(self))

    def to(self: TWeightedFsa, device: str) -> TWeightedFsa:
"""Move the tensors to a given device. This wraps around the
PyTorch `Tensor.to(device)` method."""
return WeightedFsa._make(tensor.to(device) for tensor in self)
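
    # Usage sketch (hypothetical values): scale the transition weights and move
    # the whole FSA to the GPU in one step each:
    #
    #     fsa = fsa * 0.3       # returns a new WeightedFsa with scaled weights
    #     fsa = fsa.to("cuda")  # moves all four tensors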


class TorchFsaBuilder:
"""
Builder class that wraps around the librasr.AllophoneStateFsaBuilder,
bringing the FSAs into the correct format for the `i6_native_ops.fbw.fbw_loss`.
Use of this class requires a working installation of the python package `librasr`.
This class provides an explicit implementation of the `__getstate__` and `__setstate__`
functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder`
is not picklable.
:param config_path: path to the RASR fsa exporter config
:param tdp_scale: multiply the weights by this scale
"""

    def __init__(self, config_path: str, tdp_scale: float = 1.0):
        import librasr

        self.config_path = config_path
        config = librasr.Configuration()
        config.set_from_file(self.config_path)
        self.builder = librasr.AllophoneStateFsaBuilder(config)
        self.tdp_scale = tdp_scale

    def __getstate__(self):
        # drop the C++ builder, which cannot be pickled; it is re-created in __setstate__
        state = self.__dict__.copy()
        del state["builder"]
        return state

    def __setstate__(self, state):
        import librasr  # imported locally in __init__, so it is not in scope here otherwise

        self.__dict__.update(state)
        config = librasr.Configuration()
        config.set_from_file(self.config_path)
        self.builder = librasr.AllophoneStateFsaBuilder(config)
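
    # Pickling round trip (sketch): __getstate__ drops the builder and
    # __setstate__ re-creates it from config_path, so e.g.
    #     restored = pickle.loads(pickle.dumps(builder))
    # yields a working TorchFsaBuilder in a fresh process.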

    def build_single(self, seq_tag: str) -> Tuple[int, int, np.ndarray, np.ndarray]:
        """
        Build the FSA for the given sequence tag in the corpus.
        :param seq_tag: sequence tag
        :return: FSA as a tuple containing
            * number of states S
            * number of edges E
            * integer edge array of shape [E, 3] where each row is an edge
              consisting of from-state, to-state and the emission idx
            * float weight array of shape [E,]
        """
        raw_fsa = self.builder.build_by_segment_name(seq_tag)
        return raw_fsa

    def build_batch(self, seq_tags: Iterable[str]) -> WeightedFsa:
"""
Build and concatenate the FSAs for a batch of sequence tags
and reformat as an input to `i6_native_ops.fbw.fbw_loss`.
Here the FSAs are concatenated to a long FSA with multiple start and
end states corresponding to each single FSA. For the concatenation,
the state IDs of each single FSA are incrememented and made unique in
the batch.
Additionally we apply an optional scale to the weights.
:param seq_tags: an iterable object of sequence tags
        :return: a concatenated FSA
        """

        def append_fsa(a, b):
            edges = torch.from_numpy(np.int32(b[2])).reshape((3, b[1]))
            return (
                a[0] + [b[0]],  # num states
                a[1] + [b[1]],  # num edges
                torch.hstack([a[2], edges]),  # edges
                torch.cat([a[3], torch.from_numpy(b[3])]),  # weights
            )
        # concatenate all FSAs in the batch into a single one where state ids are not yet unique
        fsas = map(self.build_single, seq_tags)
        empty_fsa = ([], [], torch.empty((3, 0), dtype=torch.int32), torch.empty((0,)))
        num_states, num_edges, all_edges, all_weights = reduce(append_fsa, fsas, empty_fsa)
        num_edges = torch.tensor(num_edges, dtype=torch.int32)
        num_states = torch.tensor(num_states, dtype=torch.int32)

        # accumulate the number of states of each single fsa in order to determine start and
        # end states and to make the states in the edge tensor unique to each sequence
        cum_num_states = torch.cumsum(num_states, dim=0, dtype=torch.int32)
        state_offsets = torch.cat([torch.zeros((1,), dtype=torch.int32), cum_num_states[:-1]])
        start_end_states = torch.vstack([state_offsets, cum_num_states - 1])
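
        # Worked example (illustrative): for num_states = [3, 2] we get
        # cum_num_states = [3, 5], state_offsets = [0, 3] and
        # start_end_states = [[0, 3], [2, 4]], i.e. sequence 0 spans states 0..2
        # and sequence 1 spans states 3..4.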

        # add unique sequence ids to the edge tensor, and add each sequence's state
        # offset to its from- and to-states in order to make them unique
        edge_seq_idxs = torch.repeat_interleave(num_edges)
        all_edges[:2, :] += torch.repeat_interleave(state_offsets, num_edges)
        all_edges = torch.vstack([all_edges, edge_seq_idxs])
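
        # Note: torch.repeat_interleave(num_edges) expands the counts into per-edge
        # sequence indices, e.g. num_edges = [2, 1] gives edge_seq_idxs = [0, 0, 1].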

        out_fsa = WeightedFsa(
            cum_num_states[-1],
            all_edges,
            all_weights,
            start_end_states,
        )

        if self.tdp_scale != 1.0:
            out_fsa *= self.tdp_scale

        return out_fsa
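
A minimal usage sketch (the config path and sequence tags below are hypothetical; requires a working `librasr` installation and a RASR fsa exporter config):

builder = TorchFsaBuilder("/path/to/fsa-exporter.config", tdp_scale=0.3)
fsa = builder.build_batch(["corpus/recording-1/1", "corpus/recording-2/1"])
fsa = fsa.to("cuda")  # move all tensors to the GPU
# fsa can now be passed to i6_native_ops.fbw.fbw_loss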