-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
287 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[deps] | ||
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8" | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Performance Benchmarks | ||
|
||
## ResNet | ||
|
||
### ResNet50 (Forward Pass) | ||
|
||
Benchmark was run on a single NVIDIA RTX 4050 GPU with 6GB of memory. | ||
|
||
| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) | | ||
| ---------- | ------------------ | ---------------------------- | | ||
| 1 | 0.00403 s | 0.00057587 s | | ||
| 4 | 0.00879 s | 0.000712372 s | | ||
| 32 | 0.05146 s | 0.000810471 s | | ||
| 128 | 0.20071 s | 0.035191948 s | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[tool.poetry] | ||
description = "Benchmarking Lux against Python frameworks" | ||
authors = ["Avik Pal <[email protected]>"] | ||
readme = "../README.md" | ||
package-mode = false | ||
|
||
[tool.poetry.dependencies] | ||
python = "~3.10" | ||
jax = {extras = ["cuda12"], version = "^0.4.35"} | ||
flax = "~0.10.1" | ||
einops = "^0.8.0" | ||
|
||
|
||
[tool.poetry.group.dev.dependencies] | ||
ipython = "^8.29.0" | ||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
using ArgParse, BenchmarkTools | ||
import Metalhead | ||
using Lux, Enzyme, Reactant, Random, Boltz | ||
|
||
Reactant.set_default_backend("gpu") | ||
|
||
function parse_commandline() | ||
s = ArgParseSettings() | ||
|
||
#! format: off | ||
@add_arg_table! s begin | ||
"--batch-size" | ||
help = "Batch size" | ||
arg_type = Vector{Int} | ||
default = [1, 4, 32, 128] | ||
|
||
"--model-size" | ||
help = "Model size" | ||
arg_type = Int | ||
default = 50 | ||
end | ||
#! format: on | ||
|
||
return parse_args(s) | ||
end | ||
|
||
function main() | ||
parsed_args = parse_commandline() | ||
dev = xla_device() | ||
|
||
model = Vision.ResNet(parsed_args["model-size"]) | ||
ps, st = Lux.setup(Random.default_rng(), model) | ||
ps_ra = ps |> dev | ||
st_ra = Lux.testmode(st) |> dev | ||
|
||
println("Param count: $(Lux.parameterlength(ps_ra))") | ||
println("State count: $(Lux.statelength(st_ra))") | ||
|
||
timings = Dict{Int, Float64}() | ||
|
||
for b in parsed_args["batch-size"] | ||
println("batch_size=$b") | ||
|
||
x = rand(Float32, 224, 224, 3, b) |> dev | ||
|
||
model_compiled = @compile model(x, ps_ra, st_ra) | ||
|
||
timings[b] = @belapsed begin | ||
y, _ = $(model_compiled)($(x), $(ps_ra), $(st_ra)) | ||
Reactant.synchronize(y) | ||
end | ||
|
||
println("Best timing: $(timings[b]) s") | ||
end | ||
|
||
println(timings) | ||
end | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import argparse | ||
import time | ||
from functools import partial | ||
from typing import Any, Tuple | ||
from collections.abc import Callable, Sequence | ||
|
||
import flax.linen as nn | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.random as random | ||
import numpy as np | ||
|
||
ModuleDef = Any | ||
|
||
|
||
class ResNetBlock(nn.Module): | ||
"""ResNet block.""" | ||
|
||
filters: int | ||
conv: ModuleDef | ||
norm: ModuleDef | ||
act: Callable | ||
strides: tuple[int, int] = (1, 1) | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
residual = x | ||
y = self.conv(self.filters, (3, 3), self.strides)(x) | ||
y = self.norm()(y) | ||
y = self.act(y) | ||
y = self.conv(self.filters, (3, 3))(y) | ||
y = self.norm(scale_init=nn.initializers.zeros_init())(y) | ||
|
||
if residual.shape != y.shape: | ||
residual = self.conv(self.filters, (1, 1), self.strides, name="conv_proj")( | ||
residual | ||
) | ||
residual = self.norm(name="norm_proj")(residual) | ||
|
||
return self.act(residual + y) | ||
|
||
|
||
class BottleneckResNetBlock(nn.Module): | ||
"""Bottleneck ResNet block.""" | ||
|
||
filters: int | ||
conv: ModuleDef | ||
norm: ModuleDef | ||
act: Callable | ||
strides: tuple[int, int] = (1, 1) | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
residual = x | ||
y = self.conv(self.filters, (1, 1))(x) | ||
y = self.norm()(y) | ||
y = self.act(y) | ||
y = self.conv(self.filters, (3, 3), self.strides)(y) | ||
y = self.norm()(y) | ||
y = self.act(y) | ||
y = self.conv(self.filters * 4, (1, 1))(y) | ||
y = self.norm(scale_init=nn.initializers.zeros_init())(y) | ||
|
||
if residual.shape != y.shape: | ||
residual = self.conv( | ||
self.filters * 4, (1, 1), self.strides, name="conv_proj" | ||
)(residual) | ||
residual = self.norm(name="norm_proj")(residual) | ||
|
||
return self.act(residual + y) | ||
|
||
|
||
class ResNet(nn.Module): | ||
"""ResNetV1.5.""" | ||
|
||
stage_sizes: Sequence[int] | ||
block_cls: ModuleDef | ||
num_classes: int | ||
num_filters: int = 64 | ||
dtype: Any = jnp.float32 | ||
act: Callable = nn.relu | ||
conv: ModuleDef = nn.Conv | ||
|
||
@nn.compact | ||
def __call__(self, x, train: bool = True): | ||
conv = partial(self.conv, use_bias=False, dtype=self.dtype) | ||
norm = partial( | ||
nn.BatchNorm, | ||
use_running_average=not train, | ||
momentum=0.9, | ||
epsilon=1e-5, | ||
dtype=self.dtype, | ||
axis_name="batch", | ||
) | ||
|
||
x = conv( | ||
self.num_filters, | ||
(7, 7), | ||
(2, 2), | ||
padding=[(3, 3), (3, 3)], | ||
name="conv_init", | ||
)(x) | ||
x = norm(name="bn_init")(x) | ||
x = nn.relu(x) | ||
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") | ||
for i, block_size in enumerate(self.stage_sizes): | ||
for j in range(block_size): | ||
strides = (2, 2) if i > 0 and j == 0 else (1, 1) | ||
x = self.block_cls( | ||
self.num_filters * 2**i, | ||
strides=strides, | ||
conv=conv, | ||
norm=norm, | ||
act=self.act, | ||
)(x) | ||
x = jnp.mean(x, axis=(1, 2)) | ||
x = nn.Dense(self.num_classes, dtype=self.dtype)(x) | ||
x = jnp.asarray(x, self.dtype) | ||
return x | ||
|
||
|
||
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) | ||
|
||
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock) | ||
|
||
ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock) | ||
|
||
ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock) | ||
|
||
ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock) | ||
|
||
ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--batch-size", type=list, default=[1, 4, 32, 128]) | ||
parser.add_argument("--model-size", type=int, default=50) | ||
args = parser.parse_args() | ||
|
||
if args.model_size == 18: | ||
model = ResNet18 | ||
elif args.model_size == 34: | ||
model = ResNet34 | ||
elif args.model_size == 50: | ||
model = ResNet50 | ||
elif args.model_size == 101: | ||
model = ResNet101 | ||
elif args.model_size == 152: | ||
model = ResNet152 | ||
elif args.model_size == 200: | ||
model = ResNet200 | ||
|
||
model = model(num_classes=1000) | ||
|
||
timings = dict() | ||
|
||
for b in args.batch_size: | ||
print(f"batch_size={b}") | ||
|
||
x = jnp.ones((b, 224, 224, 3), jnp.float32) | ||
params = model.init(random.PRNGKey(0), x, train=False) | ||
param_count = sum(x.size for x in jax.tree.leaves(params)) | ||
|
||
print(f"Param count: {param_count}") | ||
|
||
apply_fn_compiled = jax.jit(partial(model.apply, train=False)) | ||
|
||
best_timing = np.inf | ||
for i in range(25): | ||
t1 = time.time() | ||
apply_fn_compiled(params, x).block_until_ready() | ||
t2 = time.time() | ||
best_timing = min(best_timing, t2 - t1) | ||
|
||
timings[b] = best_timing | ||
print(f"Best timing: {best_timing:.5f} s") | ||
|
||
print(timings) |