import re
import unicodedata
from collections import OrderedDict

import numpy as np
from num2words import num2words

# Common Swedish abbreviations and their spoken expansions.
abbreviations = {
    "bl.a.": "bland annat",
    "kungl. maj:t": "kunglig majestät",
    "kl.": "klockan",
    "fr.o.m.": "från och med",
}
# Common OCR errors and their corrections. The keys are regular expressions
# (used with re.sub below), so regex metacharacters are escaped.
ocr_corrections = {
    r"\$": "§",
    r"bl\.a ": "bl.a. ",
    r"[Dd]\.v\.s ": "d.v.s. ",
    r"[Dd]\. v\.s\.": "d.v.s.",
    r"[Ff]r\.o\.m ": "fr.o.m. ",
    r"[Kk]ungl\. maj: t": "kungl. maj:t",
    r"m\. m\.": "m.m.",
    r"m\.m ": "m.m. ",
    r"m\. fl\.": "m.fl.",
    r"milj\. kr\.": "milj.kr.",
    r"o\. s\.v\.": "o.s.v.",
    r"s\. k\.": "s.k.",
    r"t\.o\.m,": "t.o.m.",
    r"t\.o\. m\.": "t.o.m.",
}
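# Illustrative sketch of how a single correction above is applied. The real
# pass over all corrections happens in normalize_text_with_mapping below;
# the sample sentence here is made up:
#
#     re.sub(r"m\. m\.", "m.m.", "handlingar m. m. lämnas in")
#     -> "handlingar m.m. lämnas in"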
def format_abbreviations():
    """
    Formats abbreviations into dicts that include the pattern (abbreviation)
    and replacement (expansion).
    """
    abbreviation_patterns = []
    for abbreviation, expansion in abbreviations.items():
        abbreviation_patterns.append(
            {
                "pattern": re.escape(abbreviation),
                "replacement": expansion,
                "transformation_type": "substitution",
            }
        )
    return abbreviation_patterns
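# Illustrative example of one entry in the returned list, derived from the
# abbreviations dict above:
#
#     {"pattern": re.escape("bl.a."),  # i.e. r"bl\.a\."
#      "replacement": "bland annat",
#      "transformation_type": "substitution"}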
def collect_regex_patterns(user_patterns=None):
    """
    Collects regex patterns for text normalization and substitution.

    Args:
        user_patterns (list of dict): User-supplied regex patterns with keys
            "pattern", "replacement", and "transformation_type".

    Returns:
        list of dict: Default patterns followed by any user-supplied patterns.
    """
    patterns = []
    # Include abbreviations
    patterns.extend(format_abbreviations())
    patterns.extend(
        [
            # Capture pattern groups of type '(digits) kap. (digits) §', e.g. "4 kap. 7 §".
            # Replace the numbers with ordinals: "fjärde kapitlet sjunde paragrafen".
            {
                "pattern": r"(\d+) kap\. (\d+) §",
                "replacement": lambda m: (
                    f"{num2words(int(m.group(1)), lang='sv', ordinal=True)} kapitlet "
                    f"{num2words(int(m.group(2)), lang='sv', ordinal=True)} paragrafen"
                ),
                "transformation_type": "substitution",
            },
            # Join digit groups separated by '.' or ' ' (thousands separators,
            # e.g. "10 233") into a single number and spell it out.
            {
                "pattern": r"(\d+)[\. ](\d+)",
                "replacement": lambda m: num2words(int(m.group(1) + m.group(2)), lang="sv"),
                "transformation_type": "substitution",
            },
            # Replace : between digits (e.g. clock times like "10:00") with
            # whitespace and spell out both numbers.
            {
                "pattern": r"(\d+):(\d+)",
                "replacement": lambda m: f"{num2words(int(m.group(1)), lang='sv')} {num2words(int(m.group(2)), lang='sv')}",
                "transformation_type": "substitution",
            },
            # Replace - between digits with " till " and spell out the numbers
            {
                "pattern": r"(\d+)-(\d+)",
                "replacement": lambda m: f"{num2words(int(m.group(1)), lang='sv')} till {num2words(int(m.group(2)), lang='sv')}",
                "transformation_type": "substitution",
            },
            # Replace , between digits with " komma " and spell out the numbers
            {
                "pattern": r"(\d+),(\d+)",
                "replacement": lambda m: f"{num2words(int(m.group(1)), lang='sv')} komma {num2words(int(m.group(2)), lang='sv')}",
                "transformation_type": "substitution",
            },
            # Fallback: any remaining digit pair separated by . , : - or / is
            # split into two space-separated numbers (spelled out by the digit
            # rule further down).
            {
                "pattern": r"(\d+)[\.\,\:\-\/](\d+)",
                "replacement": lambda m: f"{m.group(1)} {m.group(2)}",
                "transformation_type": "substitution",
            },
            # Replace § with 'paragrafen' if preceded by a number
            {
                "pattern": r"(?<=\d )§",
                "replacement": "paragrafen",
                "transformation_type": "substitution",
            },
            # Replace § with 'paragraf' if succeeded by a number
            {
                "pattern": r"§(?= \d)",
                "replacement": "paragraf",
                "transformation_type": "substitution",
            },
            # Remove punctuation surrounded by whitespace
            {"pattern": r"\s[^\w\s]\s", "replacement": " ", "transformation_type": "substitution"},
            # Remove remaining punctuation
            {"pattern": r"[^\w\s]", "replacement": "", "transformation_type": "deletion"},
            # Collapse runs of whitespace into a single space
            {"pattern": r"\s{2,}", "replacement": " ", "transformation_type": "substitution"},
            # Strip leading and trailing whitespace
            {"pattern": r"^\s+|\s+$", "replacement": "", "transformation_type": "deletion"},
            # Replace digits with words
            {
                "pattern": r"(\d+)",
                "replacement": lambda m: num2words(int(m.group(1)), lang="sv"),
                "transformation_type": "substitution",
            },
            # Tokenize the rest of the text into words
            {
                "pattern": r"\w+",
                "replacement": lambda m: m.group(),
                "transformation_type": "substitution",  # Not really a substitution, but we need to record the transformation
            },
        ]
    )
    # Include user-supplied patterns
    if user_patterns:
        patterns.extend(user_patterns)
    return patterns
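# Illustrative sketch of a user-supplied pattern. The key names are the ones
# collect_regex_patterns expects; the pattern itself is hypothetical:
#
#     extra = [
#         {"pattern": r"\bca\b", "replacement": "cirka", "transformation_type": "substitution"}
#     ]
#     patterns = collect_regex_patterns(user_patterns=extra)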
def record_transformation(mapping, original_text, start, end, transformation_type, replacement):
    """
    Records a transformation in the mapping with additional context for debugging.

    Args:
        mapping (list of dict): The list that stores transformation records.
        original_text (str): The original text being normalized.
        start (int): The start index of the original text span.
        end (int): The end index of the original text span.
        transformation_type (str): The type of transformation ('substitution', 'deletion', 'insertion').
        replacement (str): The replacement text (empty string for deletions).
    """
    original_span = original_text[start:end] if start is not None and end is not None else ""
    transformation_record = {
        "original_start": start,
        "original_end": end,
        "transformation_type": transformation_type,
        "replacement": replacement,
        "normalized_start": None,  # To be filled in during the apply_transformations step
        "normalized_end": None,  # To be filled in during the apply_transformations step
        "original_token": original_span,
        "normalized_token": (
            replacement.lower()
            if transformation_type == "substitution" and replacement != " "
            else None
        ),
        "start_time": None,
        "end_time": None,
        "index": None,
    }
    mapping.append(transformation_record)
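# Illustrative example of the record produced when "10" at offsets 0-2 is
# replaced with "tio"; the normalized_* and timing fields are filled in later:
#
#     {"original_start": 0, "original_end": 2,
#      "transformation_type": "substitution", "replacement": "tio",
#      "normalized_start": None, "normalized_end": None,
#      "original_token": "10", "normalized_token": "tio",
#      "start_time": None, "end_time": None, "index": None}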
def apply_transformations(text, mapping):
    """
    Applies recorded transformations to the text and updates the mapping with
    normalized positions.

    Args:
        text (str): The original text.
        mapping (list of dict): The list of transformations.

    Returns:
        str: The transformed (normalized) text.
    """
    text_length = len(text)
    modified = np.zeros(text_length, dtype=bool)  # Track modified characters using a boolean mask
    offset = 0
    normalized_text = text
    # Sort transformations by their original start position to ensure correct application order
    mapping.sort(key=lambda x: x["original_start"])
    for i, transformation in enumerate(mapping):
        original_start = transformation["original_start"]
        original_end = transformation["original_end"]
        if modified[original_start:original_end].any():
            # Skip this transformation if it overlaps with a previous one
            continue
        # Mark the characters as modified
        modified[original_start:original_end] = True
        replacement = transformation["replacement"]
        # Shift the span by the offset accumulated from earlier replacements
        adjusted_start = original_start + offset
        adjusted_end = original_end + offset
        # Apply the transformation
        normalized_text = (
            normalized_text[:adjusted_start] + replacement + normalized_text[adjusted_end:]
        )
        # Update the normalized span in the transformation record
        transformation["normalized_start"] = adjusted_start
        transformation["normalized_end"] = adjusted_start + len(replacement)
        # Update the offset for the next transformation
        offset += len(replacement) - (original_end - original_start)
        transformation["index"] = i
    return normalized_text
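# Worked example (illustrative) of the offset bookkeeping: replacing "10" at
# original span (0, 2) with "tio" yields offset = len("tio") - 2 = +1, so a
# later record at original span (3, 5) is applied at (4, 6) in the partially
# rebuilt string, and its normalized span becomes (4, 4 + len(replacement)).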
def normalize_text_with_mapping(text, user_patterns=None, combine_regexes=False):
    """
    Normalize speech text transcript while keeping track of transformations.

    Args:
        text (str): The original text to normalize.
        user_patterns (list of dict, optional): User-supplied regex patterns,
            replacements, and type of transformation.
        combine_regexes (bool, optional): Currently unused.

    Returns:
        tuple: Normalized text and list of mappings from original to new text positions.
    """
    mapping = []
    # Unicode-normalize first so that the recorded character offsets refer to
    # the same string that apply_transformations later edits (NFKC may change
    # the length of the text).
    text = unicodedata.normalize("NFKC", text)
    # Correct some OCR errors before normalization
    for key, value in ocr_corrections.items():
        text = re.sub(key, value, text)
    # Collect all regex patterns for substitutions and deletions
    transformations = collect_regex_patterns(user_patterns)
    # Track already matched character spans using a boolean mask
    modified_chars = np.zeros(len(text), dtype=bool)
    # Record transformations for each pattern match
    for pattern_dict in transformations:
        pattern = pattern_dict["pattern"]
        transformation_type = pattern_dict["transformation_type"]
        # Match against the lowercased text so patterns are effectively
        # case-insensitive; lowercasing does not change the offsets here.
        for match in re.finditer(pattern, text.lower()):
            start, end = match.span()
            if modified_chars[start:end].any():
                # Skip this match if it overlaps with a previous match
                continue
            # Mark the characters as "to be modified"
            modified_chars[start:end] = True
            # If the replacement is callable, call it with the match to build
            # the replacement string; otherwise use the string as-is.
            if callable(pattern_dict["replacement"]):
                replacement = pattern_dict["replacement"](match)
            else:
                replacement = pattern_dict["replacement"]
            record_transformation(mapping, text, start, end, transformation_type, replacement)
    # Apply the recorded transformations to the text
    normalized_text = apply_transformations(text, mapping)
    return normalized_text, mapping
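# Minimal usage sketch (expected behavior given the patterns above, not a
# verified output):
#
#     normalize_text_with_mapping("Se 4 kap. 7 § i lagen.")
#     -> ("se fjärde kapitlet sjunde paragrafen i lagen", [...mapping records...])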
def get_normalized_tokens(mapping, casing="lower"):
    """
    Extracts the normalized tokens from a mapping, keyed by record index.

    Args:
        mapping (list of dict): Transformation records from normalize_text_with_mapping.
        casing (str): "lower" keeps tokens lowercase; any other value uppercases them.

    Returns:
        tuple: OrderedDict mapping record index to token/timestamp dicts, and
            the list of normalized tokens.
    """
    normalized_mapping = OrderedDict()
    normalized_tokens = []
    for i, record in enumerate(mapping):
        if record["transformation_type"] == "substitution" and record["replacement"] != " ":
            normalized_token = (
                record["normalized_token"]
                if casing == "lower"
                else record["normalized_token"].upper()  # Swedish wav2vec2 has uppercase tokens
            )
            normalized_mapping[i] = {
                "token": normalized_token,
                "start_time": record["start_time"],  # Empty for now
                "end_time": record["end_time"],  # Empty for now
            }
            normalized_tokens.append(normalized_token)
    return normalized_mapping, normalized_tokens
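# Note that a single record can hold a multi-word token: for the sketch above,
# the token list would be expected to look like
#     ["se", "fjärde kapitlet sjunde paragrafen", "i", "lagen"]
# with the whole phrase "fjärde kapitlet sjunde paragrafen" coming from one
# substitution record.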
def add_timestamps_to_mapping(mapping, normalized_mapping):
    """
    Copies timestamps from normalized_mapping back onto the corresponding
    records in mapping. Assumes timestamps have already been added to
    normalized_mapping.
    """
    for i, record in enumerate(mapping):
        normalized_record = normalized_mapping.get(i)
        if normalized_record:
            record["start_time"] = normalized_record["start_time"]
            record["end_time"] = normalized_record["end_time"]
    return mapping
if __name__ == "__main__":
    normalized_text, mapping = normalize_text_with_mapping(
        """Vi har bl.a. sett kungl. maj:t vinka till oss - det gjorde han bra.
Kungl. maj: t var glad när han fick 10 233 kronor. Den finns i 4 kap. 7 § i lagen.
Vi samlades i rum 101-105 Fr.o.m kl. 10:00 på morgonen.""",
    )
    normalized_mapping, normalized_tokens = get_normalized_tokens(mapping)
    print(normalized_text)