From eb899cca941361920b62418d676a38d231f84996 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Fri, 23 Aug 2024 08:55:23 -0400 Subject: [PATCH 01/10] init --- i6_models/parts/fsa.py | 135 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 i6_models/parts/fsa.py diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py new file mode 100644 index 00000000..f6048f9b --- /dev/null +++ b/i6_models/parts/fsa.py @@ -0,0 +1,135 @@ +__all__ = ["TorchFsaBuilder", "WeightedFsa"] + +from functools import reduce +from typing import Iterable, NamedTuple, Tuple, TypeVar + +import numpy as np +import torch + +TWeightedFsa = TypeVar("TWeightedFsa", bound="WeightedFsa") + +class WeightedFsa(NamedTuple): + """ + Convenience class that represents an FSA. It supports scaling the weights of the + fsa by simple left-multiplication and moving the tensors to a different device. + It can simply be passed to `i6_native_ops.fbw.fbw_loss` and `i6_native_ops.fast_viterbi.align_viterbi`. + :param num_states: the total number of all states S + :param edges: a [4, E] tensor of edges where each column is an edge consisting + of from-state, to-state, emission idx and the index of the sequence + it belongs to + :param weights: a [E,] tensor of weights for each edge scaled by the tdp_scale + :param start_end_states: a [2, N] tensor of start and end states for each of the N sequences + """ + num_states: torch.IntTensor + edges: torch.IntTensor + weights: torch.FloatTensor + start_end_states: torch.IntTensor + + def __mul__(self: TWeightedFsa, scale: float) -> TWeightedFsa: + """Multiply the weights, i.e. the third element, with a scale.""" + return WeightedFsa._make( + tensor * scale if i == 2 else tensor + for i, tensor in enumerate(self) + ) + + def to(self: TWeightedFsa, device: str) -> TWeightedFsa: + """Move the tensors to a given device. This wraps around the + PyTorch `Tensor.to(device)` method.""" + return WeightedFsa._make(tensor.to(device) for tensor in self) + + +class TorchFsaBuilder: + """ + Builder class that wraps around the librasr.AllophoneStateFsaBuilder, + bringing the FSAs into the correct format for the `i6_native_ops.fbw.fbw_loss`. + Use of this class requires a working installation of the python package `librasr`. + This class provides an explicit implementation of the `__getstate__` and `__setstate__` + functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder` + is not picklable. + :param config_path: path to the RASR fsa exporter config + :param tdp_scale: multiply the weights by this scale + """ + def __init__(self, config_path: str, tdp_scale: float = 1.0): + import librasr + self.config_path = config_path + config = librasr.Configuration() + config.set_from_file(self.config_path) + self.builder = librasr.AllophoneStateFsaBuilder(config) + self.tdp_scale = tdp_scale + + def __getstate__(self): + state = self.__dict__.copy() + del state["builder"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + config = librasr.Configuration() + config.set_from_file(self.config_path) + self.builder = librasr.AllophoneStateFsaBuilder(config) + + def build_single(self, seq_tag: str) -> Tuple[int, int, np.ndarray, np.ndarray]: + """ + Build the FSA for the given sequence tag in the corpus. 
+ :param seq_tag: sequence tag + :return: FSA as a tuple containing + * number of states S + * number of edges E + * integer edge array of shape [E, 3] where each row is an edge + consisting of from-state, to-state and the emission idx + * float weight array of shape [E,] + """ + raw_fsa = self.builder.build_by_segment_name(seq_tag) + return raw_fsa + + def build_batch(self, seq_tags: Iterable[str]) -> TWeightedFsa: + """ + Build and concatenate the FSAs for a batch of sequence tags + and reformat as an input to `i6_native_ops.fbw.fbw_loss`. + Here the FSAs are concatenated to a long FSA with multiple start and + end states corresponding to each single FSA. For the concatenation, + the state IDs of each single FSA are incremented and made unique in + the batch. + Additionally, we apply an optional scale to the weights. + :param seq_tags: an iterable object of sequence tags + :return: a concatenated FSA + """ + def append_fsa(a, b): + edges = torch.from_numpy(np.int32(b[2])).reshape((3, b[1])) + return ( + a[0] + [b[0]], # num states + a[1] + [b[1]], # num edges + torch.hstack([a[2], edges]), # edges + torch.cat([a[3], torch.from_numpy(b[3])]), # weights + ) + + # concatenate all FSAs in the batch into a single one where state ids are not yet unique + fsas = map(self.build_single, seq_tags) + empty_fsa = ([], [], torch.empty((3, 0), dtype=torch.int32), torch.empty((0,))) + num_states, num_edges, all_edges, all_weights = reduce(append_fsa, fsas, empty_fsa) + num_edges = torch.tensor(num_edges, dtype=torch.int32) + num_states = torch.tensor(num_states, dtype=torch.int32) + + # accumulate number of states for each single fsa in order to determine start and end states + # and make states in edge tensor unique to each sequence + cum_num_states = torch.cumsum(num_states, dim=0, dtype=torch.int32) + state_offsets = torch.cat([torch.zeros((1,), dtype=torch.int32), cum_num_states[:-1]]) + start_end_states = torch.vstack([state_offsets, cum_num_states - 1]) + + # add unique sequence ids to the edge tensor and add start states to the states + # in order to make them unique + edge_seq_idxs = torch.repeat_interleave(num_edges) + all_edges[:2, :] += torch.repeat_interleave(state_offsets, num_edges) + all_edges = torch.vstack([all_edges, edge_seq_idxs]) + + out_fsa = WeightedFsa( + cum_num_states[-1], + all_edges, + all_weights, + start_end_states, + ) + + if self.tdp_scale != 1.0: + out_fsa *= self.tdp_scale + + return out_fsa From 43b19b24c8b4fd1a2cb4ec598c7c5be911521c28 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Fri, 23 Aug 2024 09:01:09 -0400 Subject: [PATCH 02/10] black --- i6_models/parts/fsa.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index f6048f9b..18944235 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -8,6 +8,7 @@ TWeightedFsa = TypeVar("TWeightedFsa", bound="WeightedFsa") + class WeightedFsa(NamedTuple): """ Convenience class that represents an FSA. It supports scaling the weights of the @@ -20,6 +21,7 @@ class WeightedFsa(NamedTuple): :param weights: a [E,] tensor of weights for each edge scaled by the tdp_scale :param start_end_states: a [2, N] tensor of start and end states for each of the N sequences """ + num_states: torch.IntTensor edges: torch.IntTensor weights: torch.FloatTensor start_end_states: torch.IntTensor @@ -27,11 +29,8 @@ class WeightedFsa(NamedTuple): def __mul__(self: TWeightedFsa, scale: float) -> TWeightedFsa: """Multiply the weights, i.e. 
the third element, with a scale.""" - return WeightedFsa._make( - tensor * scale if i == 2 else tensor - for i, tensor in enumerate(self) - ) - + return WeightedFsa._make(tensor * scale if i == 2 else tensor for i, tensor in enumerate(self)) + def to(self: TWeightedFsa, device: str) -> TWeightedFsa: """Move the tensors to a given device. This wraps around the PyTorch `Tensor.to(device)` method.""" @@ -49,8 +48,10 @@ class TorchFsaBuilder: :param config_path: path to the RASR fsa exporter config :param tdp_scale: multiply the weights by this scale """ + def __init__(self, config_path: str, tdp_scale: float = 1.0): import librasr + self.config_path = config_path config = librasr.Configuration() config.set_from_file(self.config_path) @@ -61,7 +62,7 @@ def __getstate__(self): state = self.__dict__.copy() del state["builder"] return state - + def __setstate__(self, state): self.__dict__.update(state) config = librasr.Configuration() @@ -94,6 +95,7 @@ def build_batch(self, seq_tags: Iterable[str]) -> TWeightedFsa: :param seq_tags: an iterable object of sequence tags :return: a concatenated FSA """ + def append_fsa(a, b): edges = torch.from_numpy(np.int32(b[2])).reshape((3, b[1])) return ( From af9b31766a66242c8846c365a22851c2559d67ef Mon Sep 17 00:00:00 2001 From: DanEnergetics Date: Mon, 26 Aug 2024 15:25:09 +0200 Subject: [PATCH 03/10] update typing Co-authored-by: michelwi --- i6_models/parts/fsa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index 18944235..aef723a8 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -83,7 +83,7 @@ def build_single(self, seq_tag: str) -> Tuple[int, int, np.ndarray, np.ndarray]: raw_fsa = self.builder.build_by_segment_name(seq_tag) return raw_fsa - def build_batch(self, seq_tags: Iterable[str]) -> TWeightedFsa: + def build_batch(self, seq_tags: Iterable[str]) -> WeightedFsa: """ Build and concatenate the FSAs for a batch of sequence tags and reformat as an input to `i6_native_ops.fbw.fbw_loss`. From 97494ca744d4e3d96bbc02dabcad2eade9bc2353 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Mon, 26 Aug 2024 09:27:46 -0400 Subject: [PATCH 04/10] update docs --- i6_models/parts/fsa.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index aef723a8..06c0b64f 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -15,9 +15,8 @@ class WeightedFsa(NamedTuple): fsa by simple left-multiplication and moving the tensors to a different device. It can simply be passed to `i6_native_ops.fbw.fbw_loss` and `i6_native_ops.fast_viterbi.align_viterbi`. 
:param num_states: the total number of all states S - :param edges: a [4, E] tensor of edges where each column is an edge consisting - of from-state, to-state, emission idx and the index of the sequence - it belongs to + :param edges: a [4, E] tensor of edges with number of edges E and where each column is an edge + consisting of from-state, to-state, emission idx and the index of the sequence it belongs to :param weights: a [E,] tensor of weights for each edge scaled by the tdp_scale :param start_end_states: a [2, N] tensor of start and end states for each of the N sequences """
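(Aside for reviewers, not part of the patch series.) The `__mul__` rewrite in PATCH 05 below is behavior-preserving: both the old `_make`-based generator and the new explicit constructor scale only the `weights` field, i.e. element 2 of the tuple. A minimal standalone sanity check with made-up dummy tensors, using the module path as it exists at this point in the series (renamed in PATCH 10):

import torch
from i6_models.parts.fsa import WeightedFsa  # pre-rename module path

# dummy FSA with two states and a single edge; all values are made up for illustration
fsa = WeightedFsa(
    num_states=torch.tensor(2, dtype=torch.int32),
    edges=torch.zeros((4, 1), dtype=torch.int32),
    weights=torch.tensor([0.5]),
    start_end_states=torch.tensor([[0], [1]], dtype=torch.int32),
)
scaled = fsa * 2.0
assert torch.equal(scaled.weights, torch.tensor([1.0]))  # only the weights are scaled
assert torch.equal(scaled.edges, fsa.edges)  # every other field passes through unchanged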
From 40459ee522196eab7948b7580c853a613cce743f Mon Sep 17 00:00:00 2001 From: DanEnergetics Date: Tue, 27 Aug 2024 17:01:05 +0200 Subject: [PATCH 05/10] Apply suggestions from code review Co-authored-by: Albert Zeyer --- i6_models/parts/fsa.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index 06c0b64f..f44cdc54 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -1,5 +1,6 @@ __all__ = ["TorchFsaBuilder", "WeightedFsa"] +from __future__ import annotations from functools import reduce from typing import Iterable, NamedTuple, Tuple, TypeVar import numpy as np import torch @@ -13,7 +14,7 @@ class WeightedFsa(NamedTuple): """ Convenience class that represents an FSA. It supports scaling the weights of the fsa by simple left-multiplication and moving the tensors to a different device. - It can simply be passed to `i6_native_ops.fbw.fbw_loss` and `i6_native_ops.fast_viterbi.align_viterbi`. + It can simply be passed to :func:`i6_native_ops.fbw.fbw_loss` and :func:`i6_native_ops.fast_viterbi.align_viterbi`. :param num_states: the total number of all states S :param edges: a [4, E] tensor of edges with number of edges E and where each column is an edge consisting of from-state, to-state, emission idx and the index of the sequence it belongs to :param weights: a [E,] tensor of weights for each edge scaled by the tdp_scale :param start_end_states: a [2, N] tensor of start and end states for each of the N sequences """ @@ -26,11 +27,16 @@ class WeightedFsa(NamedTuple): weights: torch.FloatTensor start_end_states: torch.IntTensor - def __mul__(self: TWeightedFsa, scale: float) -> TWeightedFsa: + def __mul__(self: TWeightedFsa, scale: float) -> WeightedFsa: """Multiply the weights, i.e. the third element, with a scale.""" - return WeightedFsa._make(tensor * scale if i == 2 else tensor for i, tensor in enumerate(self)) + return WeightedFsa( + self.num_states, + self.edges, + self.weights * scale, + self.start_end_states, + ) - def to(self: TWeightedFsa, device: str) -> TWeightedFsa: + def to(self, device: str) -> WeightedFsa: """Move the tensors to a given device. This wraps around the PyTorch `Tensor.to(device)` method.""" return WeightedFsa._make(tensor.to(device) for tensor in self) @@ -71,6 +77,7 @@ def build_single(self, seq_tag: str) -> Tuple[int, int, np.ndarray, np.ndarray]: """ Build the FSA for the given sequence tag in the corpus. + :param seq_tag: sequence tag :return: FSA as a tuple containing * number of states S * number of edges E * integer edge array of shape [E, 3] where each row is an edge consisting of from-state, to-state and the emission idx * float weight array of shape [E,] """ raw_fsa = self.builder.build_by_segment_name(seq_tag) return raw_fsa From 838cc76bdb386cf741879e1109a7b3577920ece0 Mon Sep 17 00:00:00 2001 From: DanEnergetics Date: Tue, 27 Aug 2024 17:01:49 +0200 Subject: [PATCH 06/10] Apply suggestions from code review Co-authored-by: Albert Zeyer --- i6_models/parts/fsa.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index f44cdc54..fac19d57 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -7,8 +7,6 @@ import numpy as np import torch -TWeightedFsa = TypeVar("TWeightedFsa", bound="WeightedFsa") - class WeightedFsa(NamedTuple): """ @@ -98,6 +96,7 @@ def build_batch(self, seq_tags: Iterable[str]) -> WeightedFsa: the state IDs of each single FSA are incremented and made unique in the batch. Additionally, we apply an optional scale to the weights. + :param seq_tags: an iterable object of sequence tags :return: a concatenated FSA From 99609c751cf6aa9a241c1518bbd28dc184efa508 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Tue, 27 Aug 2024 11:06:38 -0400 Subject: [PATCH 07/10] typing --- i6_models/parts/fsa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index fac19d57..1ca20ab5 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -25,7 +25,7 @@ class WeightedFsa(NamedTuple): weights: torch.FloatTensor start_end_states: torch.IntTensor - def __mul__(self: TWeightedFsa, scale: float) -> WeightedFsa: + def __mul__(self, scale: float) -> WeightedFsa: """Multiply the weights, i.e. the third element, with a scale.""" return WeightedFsa( self.num_states, From a81a16b59130c9bc07d4880ae0a19d561ead3210 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Tue, 27 Aug 2024 11:27:19 -0400 Subject: [PATCH 08/10] change librasr import --- i6_models/parts/fsa.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index 1ca20ab5..7ae0328e 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -1,11 +1,13 @@ +from __future__ import annotations + __all__ = ["TorchFsaBuilder", "WeightedFsa"] -from __future__ import annotations from functools import reduce from typing import Iterable, NamedTuple, Tuple, TypeVar import numpy as np import torch +import librasr class WeightedFsa(NamedTuple): @@ -13,6 +15,7 @@ class WeightedFsa(NamedTuple): Convenience class that represents an FSA. It supports scaling the weights of the fsa by simple left-multiplication and moving the tensors to a different device. It can simply be passed to :func:`i6_native_ops.fbw.fbw_loss` and :func:`i6_native_ops.fast_viterbi.align_viterbi`. + :param num_states: the total number of all states S :param edges: a [4, E] tensor of edges with number of edges E and where each column is an edge consisting of from-state, to-state, emission idx and the index of the sequence it belongs to @@ -34,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa: self.start_end_states, ) - def to(self, device: str) -> WeightedFsa: + def to(self, device: torch.device) -> WeightedFsa: """Move the tensors to a given device. 
This wraps around the PyTorch `Tensor.to(device)` method.""" return WeightedFsa._make(tensor.to(device) for tensor in self) @@ -48,13 +51,12 @@ class TorchFsaBuilder: This class provides an explicit implementation of the `__getstate__` and `__setstate__` functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder` is not picklable. + :param config_path: path to the RASR fsa exporter config :param tdp_scale: multiply the weights by this scale """ def __init__(self, config_path: str, tdp_scale: float = 1.0): - import librasr - self.config_path = config_path config = librasr.Configuration() config.set_from_file(self.config_path) @@ -75,7 +77,7 @@ def __setstate__(self, state): def build_single(self, seq_tag: str) -> Tuple[int, int, np.ndarray, np.ndarray]: """ Build the FSA for the given sequence tag in the corpus. - + :param seq_tag: sequence tag :return: FSA as a tuple containing * number of states S @@ -96,7 +98,7 @@ def build_batch(self, seq_tags: Iterable[str]) -> WeightedFsa: the state IDs of each single FSA are incrememented and made unique in the batch. Additionally we apply an optional scale to the weights. - + :param seq_tags: an iterable object of sequence tags :return: a concatenated FSA """ From e5b928d79e09ea45b94b26a31edf177698c99fde Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Wed, 28 Aug 2024 04:59:03 -0400 Subject: [PATCH 09/10] typing --- i6_models/parts/fsa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index 7ae0328e..f4b2df22 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -3,7 +3,7 @@ __all__ = ["TorchFsaBuilder", "WeightedFsa"] from functools import reduce -from typing import Iterable, NamedTuple, Tuple, TypeVar +from typing import Iterable, NamedTuple, Tuple, Union import numpy as np import torch @@ -37,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa: self.start_end_states, ) - def to(self, device: torch.device) -> WeightedFsa: + def to(self, device: Union[str, torch.device]) -> WeightedFsa: """Move the tensors to a given device. This wraps around the PyTorch `Tensor.to(device)` method.""" return WeightedFsa._make(tensor.to(device) for tensor in self) From 4d5e3c21554d3169694da7ea350e4198e2784cb1 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Thu, 29 Aug 2024 10:21:54 -0400 Subject: [PATCH 10/10] local librasr import + renaming --- i6_models/parts/{fsa.py => rasr_fsa.py} | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) rename i6_models/parts/{fsa.py => rasr_fsa.py} (96%) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/rasr_fsa.py similarity index 96% rename from i6_models/parts/fsa.py rename to i6_models/parts/rasr_fsa.py index f4b2df22..bd328c49 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/rasr_fsa.py @@ -1,13 +1,12 @@ from __future__ import annotations -__all__ = ["TorchFsaBuilder", "WeightedFsa"] +__all__ = ["RasrFsaBuilder", "WeightedFsa"] from functools import reduce from typing import Iterable, NamedTuple, Tuple, Union import numpy as np import torch -import librasr class WeightedFsa(NamedTuple): @@ -43,11 +42,13 @@ def to(self, device: Union[str, torch.device]) -> WeightedFsa: return WeightedFsa._make(tensor.to(device) for tensor in self) -class TorchFsaBuilder: +class RasrFsaBuilder: """ Builder class that wraps around the librasr.AllophoneStateFsaBuilder, bringing the FSAs into the correct format for the `i6_native_ops.fbw.fbw_loss`. 
Use of this class requires a working installation of the python package `librasr`. + Hence, the package is imported locally so that other classes in + this module remain usable without it. This class provides an explicit implementation of the `__getstate__` and `__setstate__` functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder` is not picklable. @@ -57,6 +58,8 @@ class TorchFsaBuilder: """ def __init__(self, config_path: str, tdp_scale: float = 1.0): + import librasr + self.config_path = config_path config = librasr.Configuration() config.set_from_file(self.config_path) @@ -69,6 +72,8 @@ def __getstate__(self): return state def __setstate__(self, state): + import librasr + self.__dict__.update(state) config = librasr.Configuration() config.set_from_file(self.config_path)
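Appendix (not part of the patches above). First, a standalone sketch of the state-offset bookkeeping that `build_batch` performs, traced for two toy FSAs with made-up edges; it mirrors the arithmetic from PATCH 01 and runs with plain PyTorch, no `librasr` needed:

import torch

num_states = torch.tensor([2, 3], dtype=torch.int32)  # states per FSA
num_edges = torch.tensor([1, 2], dtype=torch.int32)  # edges per FSA
# rows: from-state, to-state, emission idx; state IDs are still local to each FSA
all_edges = torch.tensor([[0, 0, 1], [1, 1, 2], [5, 7, 7]], dtype=torch.int32)

cum_num_states = torch.cumsum(num_states, dim=0, dtype=torch.int32)  # [2, 5]
state_offsets = torch.cat([torch.zeros((1,), dtype=torch.int32), cum_num_states[:-1]])  # [0, 2]
start_end_states = torch.vstack([state_offsets, cum_num_states - 1])  # [[0, 2], [1, 4]]

# the single-argument form of repeat_interleave yields one sequence index per edge: [0, 1, 1]
edge_seq_idxs = torch.repeat_interleave(num_edges)
all_edges[:2, :] += torch.repeat_interleave(state_offsets, num_edges)  # shift from-/to-states
all_edges = torch.vstack([all_edges, edge_seq_idxs])
# all_edges is now [[0, 2, 3], [1, 3, 4], [5, 7, 7], [0, 1, 1]]: globally unique
# state IDs in the first two rows plus the owning sequence index in the last row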
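Second, a hypothetical end-to-end use of the final `RasrFsaBuilder` API. The config path and segment tags below are placeholders, a working `librasr` installation is assumed, and the exact tag format depends on the RASR corpus:

import pickle

from i6_models.parts.rasr_fsa import RasrFsaBuilder

builder = RasrFsaBuilder("/path/to/fsa-exporter.config", tdp_scale=0.3)  # placeholder path and scale
fsa = builder.build_batch(["corpus/recording-1/1", "corpus/recording-2/7"])  # placeholder tags
fsa = fsa.to("cuda")  # WeightedFsa.to() applies Tensor.to(device) to all four fields
fsa = fsa * 2.0  # extra weight scaling via __mul__, on top of tdp_scale
# the builder stays picklable although it wraps a C++ object: __getstate__ drops
# the builder and __setstate__ reconstructs it from config_path
builder_copy = pickle.loads(pickle.dumps(builder))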