Skip to content

Commit 3b473a3

Browse files
authored
Add audio decoding benchmarks (#580)
1 parent 7ed3779 commit 3b473a3

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import subprocess
2+
3+
from argparse import ArgumentParser
4+
from datetime import timedelta
5+
from pathlib import Path
6+
from time import perf_counter_ns
7+
8+
import torch
9+
import torchaudio
10+
from torch import Tensor
11+
from torchaudio.io import StreamReader
12+
from torchcodec.decoders._audio_decoder import AudioDecoder
13+
14+
# Default number of timed runs per benchmark. Used as the default for both
# bench()'s ``num_exp`` argument and the ``--num-exp`` command-line option.
DEFAULT_NUM_EXP = 30
15+
16+
17+
def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1, **kwargs) -> Tensor:
    """Time repeated calls of ``f(*args, **kwargs)``.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        num_exp: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        A float tensor holding one elapsed time per timed call, in
        nanoseconds (differences of ``perf_counter_ns()`` readings).
    """
    # Untimed warm-up calls; their results and durations are discarded.
    for _ in range(warmup):
        f(*args, **kwargs)

    elapsed_ns = [0] * num_exp
    for run in range(num_exp):
        tic = perf_counter_ns()
        f(*args, **kwargs)
        toc = perf_counter_ns()
        elapsed_ns[run] = toc - tic

    return torch.tensor(elapsed_ns).float()
29+
30+
31+
def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float:
    """Print summary statistics for a set of timings and return the median.

    Args:
        times: Tensor of timings expressed in nanoseconds.
        unit: Unit to report in; one of ``"ns"``, ``"µs"``, ``"ms"``, ``"s"``.
        prefix: Label printed at the start of the line, left-aligned and
            padded to 40 characters.

    Returns:
        The median timing, converted to ``unit``.

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    # Multipliers converting nanoseconds into the requested unit.
    ns_to_unit = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }
    scaled = times * ns_to_unit[unit]

    std = scaled.std().item()
    med = scaled.median().item()
    mean = scaled.mean().item()
    # Named ``lowest``/``highest`` so the ``min``/``max`` builtins are not
    # shadowed; the printed labels are still "min" and "max".
    lowest = scaled.min().item()
    highest = scaled.max().item()

    print(
        f"{prefix:<40} med = {med:.2f}, mean = {mean:.2f} +- {std:.2f}, "
        f"min = {lowest:.2f}, max = {highest:.2f} - in {unit}"
    )
    return med
47+
48+
49+
def get_duration(path: Path) -> str:
    """Return the duration of the media file at ``path`` as a display string.

    The duration is queried by running ``ffprobe`` and formatted with
    ``str(timedelta(...))``, with any fractional-seconds part removed.
    This is best-effort: if ``ffprobe`` cannot be run or its output cannot
    be interpreted as a number of seconds, ``"?"`` is returned.

    Args:
        path: Path of the file to probe.

    Returns:
        The formatted duration, or ``"?"`` when it cannot be determined.
    """
    command = [
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "format=duration",
        "-of",
        "default=noprint_wrappers=1:nokey=1",
        str(path),
    ]
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        seconds = float(result.stdout.strip())
        # Remove microseconds
        return str(timedelta(seconds=seconds)).split(".")[0]
    except (OSError, subprocess.SubprocessError, ValueError, OverflowError):
        # OSError: ffprobe is missing or not executable.
        # ValueError / OverflowError: empty, non-numeric or non-finite output.
        # Anything else is a programming error and should surface.
        return "?"
71+
72+
73+
def decode_with_torchcodec(path: Path) -> None:
    """Decode the audio in ``path`` with torchcodec's ``AudioDecoder``.

    Requests samples from second 0 with no stop bound; the decoded samples
    are discarded, since only the decoding work is of interest.
    """
    decoder = AudioDecoder(path)
    decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=None)
75+
76+
77+
def decode_with_torchaudio_StreamReader(path: Path) -> None:
    """Decode the audio in ``path`` with torchaudio's ``StreamReader``.

    Adds one audio stream with 1024 frames per chunk and iterates the
    stream to exhaustion, discarding every chunk.
    """
    stream_reader = StreamReader(path)
    stream_reader.add_audio_stream(frames_per_chunk=1024)
    chunks = stream_reader.stream()
    # Drain the stream; the chunks themselves are not needed.
    for _chunk in chunks:
        continue
82+
83+
84+
def decode_with_torchaudio_load(path: Path, backend: str) -> None:
    """Decode the file at ``path`` with ``torchaudio.load`` using ``backend``.

    The loaded result is discarded; only the decoding work is of interest.
    """
    uri = str(path)
    torchaudio.load(uri, backend=backend)
86+
87+
88+
# Command-line interface: the file to decode and how many timed runs to do.
parser = ArgumentParser()
parser.add_argument("--path", required=True, type=str, help="path to file")
parser.add_argument(
    "--num-exp",
    default=DEFAULT_NUM_EXP,
    type=int,
    help="number of runs to average over",
)
args = parser.parse_args()

path = Path(args.path)
num_exp = args.num_exp

header = f"Benchmarking {path.name}, duration: {get_duration(path)}, averaging over {num_exp} runs:"
print(header)

# torchcodec
torchcodec_times = bench(decode_with_torchcodec, path, num_exp=num_exp)
report_stats(torchcodec_times, prefix="torchcodec.AudioDecoder")

# torchaudio.load with the ffmpeg backend
ffmpeg_times = bench(
    decode_with_torchaudio_load, path, backend="ffmpeg", num_exp=num_exp
)
report_stats(ffmpeg_times, prefix="torchaudio.load(backend='ffmpeg')")

# torchaudio.load with the sox backend; a RuntimeError raised while
# benchmarking or reporting is shown as "Not supported".
prefix = "torchaudio.load(backend='sox')"
try:
    sox_times = bench(
        decode_with_torchaudio_load, path, backend="sox", num_exp=num_exp
    )
    report_stats(sox_times, prefix=prefix)
except RuntimeError:
    print(f"{prefix:<40} Not supported")

# torchaudio StreamReader
stream_reader_times = bench(
    decode_with_torchaudio_StreamReader, path, num_exp=num_exp
)
report_stats(stream_reader_times, prefix="torchaudio.StreamReader")

0 commit comments

Comments
 (0)