Skip to content

Commit

Permalink
use l1 norm instead of softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
Patchethium committed Aug 13, 2023
1 parent b89fdd8 commit 6eaf892
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ I'll cover this part if it's needed by anyone. Please let me know by creating an

- Rust crate
- multi-language
- Storing `pau` index in binary model
- Option to convert frame number into milisecond
- Record and warn the user when score is too low

## Licence

Expand Down
15 changes: 9 additions & 6 deletions src/snfa/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def softmax(x, axis=-1):
e_x = np.exp(x - np.max(x, axis, keepdims=True))
return e_x / np.sum(e_x, axis, keepdims=True)

def l1_normalize(arr, axis=None):
arr = arr - np.min(arr)
norm = np.sum(np.abs(arr), axis=axis, keepdims=True)
normalized_arr = arr / norm
return normalized_arr

def log_softmax(x, axis=-1):
return np.log(softmax(x, axis))
Expand Down Expand Up @@ -147,9 +152,6 @@ def mel(self, x: np.ndarray):
mel = np.fliplr(mel)
return mel

def _norm_labels(self, labels) -> np.ndarray:
return softmax(labels[:, 1:], axis=1)

def get_indices(self, ph):
try:
tokens = np.array([int(self.phone_set.index(p)) for p in ph])
Expand All @@ -162,13 +164,14 @@ def align(self, x, ph):
indices = self.get_indices(ph)

labels = self.model_forward(mel)
labels = self._norm_labels(labels)

trellis = viterbi.get_trellis(labels, indices)
emission = l1_normalize(labels[:, 1:], axis=1)[:, indices]

trellis = viterbi.get_trellis(emission)
path = viterbi.backtrack(trellis)

segments = viterbi.merge_repeats(path, indices)
return segments, path, trellis, labels
return segments, path, trellis, emission, labels

def __call__(self, x: np.ndarray, ph: List[str]):
return self.align(x, ph)
Expand Down
10 changes: 5 additions & 5 deletions src/snfa/viterbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ class Point:
time_index: int


def get_trellis(emission: np.ndarray, tokens: np.ndarray) -> np.ndarray:
def get_trellis(emission: np.ndarray) -> np.ndarray:
"""
Get a cost matrix `trellis` from emission
using Viterbi algorithm.
"""
num_frames, num_tokens = emission.shape[0], tokens.shape[0]
num_frames, num_tokens = emission.shape
trellis = np.zeros((num_frames, num_tokens))
trellis[1:, 0] = np.cumsum(emission[1:, tokens[0]], 0)
trellis[1:, 0] = np.cumsum(emission[1:, 0], 0)
trellis[0, 1:] = -np.inf
trellis[-num_tokens + 1 :, 0] = np.inf

for t in range(num_frames - 1):
candidate = np.maximum(
trellis[t, 1:] + emission[t + 1, tokens[1:]],
trellis[t, :-1] + emission[t + 1, tokens[1:]],
trellis[t, 1:] + emission[t + 1, 1:],
trellis[t, :-1] + emission[t + 1, 1:],
)
trellis[t + 1, 1:] = candidate

Expand Down

0 comments on commit 6eaf892

Please sign in to comment.