-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathchunker.py
61 lines (52 loc) · 2.42 KB
/
chunker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
import string
import json
from transformers import AutoTokenizer
# Tokenizer used by main() to report how many tokens each chunk contains.
# It is created once, at import time, so importing this module loads it.
tokenizer = AutoTokenizer.from_pretrained(
    'hf-internal-testing/llama-tokenizer',
    use_fast=True,
)
def segment_merger(filename, max_text_len=1000):
    """Merge consecutive same-speaker transcript segments into larger runs.

    Reads a JSON file holding a list of segment dicts, each with ``'speaker'``,
    ``'text'``, ``'start'`` and ``'end'`` keys, and yields one dict per run of
    consecutive segments attributed to the same speaker.

    A run is also cut early once its text is longer than ``max_text_len``
    characters and ends (ignoring trailing whitespace) on a punctuation
    character other than a comma. This way a long monologue is broken at a
    sentence boundary instead of growing without bound.

    Args:
        filename: Path to the JSON file of segments.
        max_text_len: Soft character limit for one run; ``0`` or a negative
            value disables the early cut.

    Yields:
        ``{'speaker', 'text', 'start', 'end'}`` dicts. ``start`` comes from the
        run's first segment, ``end`` from its last, and ``text`` is the
        concatenated segment text with leading whitespace stripped.
    """
    # Close the file deterministically; JSON text is UTF-8 by specification,
    # so don't depend on the platform's locale encoding.
    with open(filename, encoding='utf-8') as f:
        segments = json.load(f)

    # Any punctuation except ',' counts as a safe place to end a run.
    stop_suffixes = tuple(string.punctuation.replace(',', ''))

    text = ''
    start_time = None
    last_segment = None  # None means "no segment seen yet"
    for segment in segments:
        # Strip trailing whitespace so "Done. " still counts as a sentence end.
        early_break = (
            max_text_len > 0
            and len(text) > max_text_len
            and text.rstrip().endswith(stop_suffixes)
        )
        # The first segment always opens a run. Checking for "no previous
        # segment" explicitly avoids confusing it with a real None speaker.
        new_run = (
            last_segment is None
            or last_segment['speaker'] != segment['speaker']
            or early_break
        )
        if new_run:
            if text != '':
                yield {'speaker': last_segment['speaker'], 'text': text,
                       'start': start_time, 'end': last_segment['end']}
            text = segment['text'].lstrip()
            start_time = segment['start']
        else:
            text += segment['text']
        last_segment = segment

    # Emit the final run; text is non-empty only if at least one segment was seen.
    if text != '':
        yield {'speaker': last_segment['speaker'], 'text': text,
               'start': start_time, 'end': last_segment['end']}
def time_splitter(merged_segments, chunk_size=300):
    """Group speaker runs into transcript chunks spanning about ``chunk_size`` seconds.

    Each input run (as produced by ``segment_merger``) is rendered as a
    ``"<speaker>: <text>"`` line and appended to the current chunk. A chunk
    is closed as soon as the run just added ends at least ``chunk_size``
    after the chunk's first run started, so a chunk can overshoot
    ``chunk_size`` by up to one run. Runs left over after the last full
    chunk are emitted as a final, shorter chunk.

    Args:
        merged_segments: Iterable of ``{'speaker', 'text', 'start', 'end'}`` dicts.
        chunk_size: Target chunk duration, in the same units as ``start``/``end``.

    Yields:
        ``{'text', 'start', 'end', 'speakers'}`` dicts. ``speakers`` lists the
        chunk's distinct speakers in order of first appearance.
    """
    start_time = None
    end_time = None
    lines = []
    speakers = []
    for segment in merged_segments:
        if start_time is None:
            start_time = segment['start']
        end_time = segment['end']
        if segment['speaker'] not in speakers:
            speakers.append(segment['speaker'])
        lines.append(f"{segment['speaker']}: {segment['text']}\n")
        if end_time - start_time >= chunk_size:
            yield {'text': ''.join(lines), 'start': start_time,
                   'end': end_time, 'speakers': speakers}
            # Rebind (not clear) so the yielded chunk keeps its own lists.
            start_time = None
            lines = []
            speakers = []

    # Flush the trailing partial chunk; without this, everything after the
    # last full chunk (the end of the transcript) is silently dropped.
    if lines:
        yield {'text': ''.join(lines), 'start': start_time,
               'end': end_time, 'speakers': speakers}
def main(prefix: str, chunk_size: int = 300, max_text_len: int = 800):
    """Chunk a diarized transcript and report the size of every chunk.

    Reads ``<prefix>.diarize.json`` and merges same-speaker segments with
    ``segment_merger``. It then groups them into time-based chunks with
    ``time_splitter`` and writes the chunks as a JSON list to
    ``<prefix>.chunk.json``. Finally it prints token, character and duration
    counts per chunk, followed by a summary.

    Args:
        prefix: Path prefix shared by the input and output files.
        chunk_size: Target chunk duration, forwarded to ``time_splitter``.
        max_text_len: Soft character limit per speaker run, forwarded to
            ``segment_merger``.
    """
    source_path = prefix + '.diarize.json'
    chunk_path = prefix + '.chunk.json'

    speaker_runs = list(segment_merger(source_path, max_text_len))
    chunks = list(time_splitter(speaker_runs, chunk_size))

    with open(chunk_path, 'w') as out:
        json.dump(chunks, out)

    largest = 0
    for index, chunk in enumerate(chunks):
        # encode() returns token ids; only their count is needed here.
        n_tokens = len(tokenizer.encode(chunk['text']))
        largest = max(largest, n_tokens)
        duration = int(chunk['end'] - chunk['start'])
        print(f"Segment {index}: {n_tokens} tokens, {len(chunk['text'])} characters, {duration} seconds")
    print(f"Largest chunk was {largest} tokens")
    print(f"Wrote {len(chunks)} chunks to {chunk_path}")
if __name__ == "__main__":
    # Expose main() as a command-line tool: fire maps command-line arguments
    # onto main's parameters (prefix, chunk_size, max_text_len).
    from fire import Fire

    Fire(main)