perf: benchmarking our models against Jax (Flax) #1000

Open · wants to merge 7 commits into base: main
8 changes: 8 additions & 0 deletions .gitignore
@@ -44,3 +44,11 @@ benchmarks/results
# Generated by tutorials
pinn_nested_ad.gif
*.mlir

# poetry
poetry.lock
*.egg-info
*.pyc
.venv
.python-version
__pycache__
8 changes: 8 additions & 0 deletions perf/Project.toml
@@ -0,0 +1,8 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
32 changes: 32 additions & 0 deletions perf/README.md
@@ -0,0 +1,32 @@
# Performance Benchmarks

## ResNet

Benchmarks were run on a single NVIDIA RTX 4050 GPU with 6 GB of memory. A quick speedup summary follows the tables below.

### ResNet18 (Forward Pass)

| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) | Best Timing (Lux) |
| ---------- | ------------------ | ---------------------------- | ----------------- |
| 1 | 0.00249 s | 0.000272982 s | 0.002161114 s |
| 4 | 0.00381 s | 0.000322524 s | 0.003498441 s |
| 32 | 0.01796 s | 0.000364948 s | 0.027250628 s |
| 128 | 0.06757 s | 0.000545296 s | 0.115965297 s |

### ResNet34 (Forward Pass)

| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) | Best Timing (Lux) |
| ---------- | ------------------ | ---------------------------- | ----------------- |
| 1 | 0.00462 s | 0.000547826 s | 0.003684532 s |
| 4 | 0.00696 s | 0.000839503 s | 0.006234771 s |
| 32 | 0.03169 s | 0.000737906 s | 0.046339233 s |
| 128 | 0.12129 s | 0.001383708 s | 0.640747518 s |

### ResNet50 (Forward Pass)

| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) | Best Timing (Lux) |
| ---------- | ------------------ | ---------------------------- | ----------------- |
| 1 | 0.00403 s | 0.001212556 s | 0.004382536 s |
| 4 | 0.00788 s | 0.000745961 s | 0.011562075 s |
| 32 | 0.05146 s | 0.000783826 s | 0.103826668 s |
| 128 | 0.20071 s | 0.001340597 s | 0.430018518 s |
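
To put the tables in context: at batch size 128 the Reactant-compiled forward pass comes out one to two orders of magnitude faster than Flax in these runs. A quick sketch of the arithmetic, with the timings copied from the tables above:

```julia
# Speedup of Lux + Reactant over Flax at batch size 128 (seconds, from the tables).
flax     = Dict(18 => 0.06757, 34 => 0.12129, 50 => 0.20071)
reactant = Dict(18 => 0.000545296, 34 => 0.001383708, 50 => 0.001340597)

for depth in (18, 34, 50)
    println("ResNet$depth: ", round(flax[depth] / reactant[depth]; digits=1), "x")
end
# ResNet18: 123.9x  ResNet34: 87.7x  ResNet50: 149.7x
```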
19 changes: 19 additions & 0 deletions perf/pyproject.toml
@@ -0,0 +1,19 @@
[tool.poetry]
description = "Benchmarking Lux against Python frameworks"
authors = ["Avik Pal <[email protected]>"]
readme = "../README.md"
package-mode = false

[tool.poetry.dependencies]
python = "~3.10"
jax = {extras = ["cuda12"], version = "^0.4.35"}
flax = "~0.10.1"
einops = "^0.8.0"


[tool.poetry.group.dev.dependencies]
ipython = "^8.29.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
35 changes: 35 additions & 0 deletions perf/resnet/lux.jl
@@ -0,0 +1,35 @@
using Comonicon, BenchmarkTools
using Lux, LuxCUDA, Random

include("resnet.jl")

Comonicon.@main function main(;
    batch_size::Vector{Int}=[1, 4, 32, 128], model_size::Int=50
)
    dev = gpu_device(; force=true)

    model = ResNet(model_size)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    println("Param count: $(Lux.parameterlength(ps))")
    println("State count: $(Lux.statelength(st))")

    timings = Dict{Int, Float64}()

    for b in batch_size
        println("batch_size=$b")

        x = rand(Float32, 224, 224, 3, b) |> dev

        # Synchronize inside the timed block so the asynchronous GPU kernels
        # are measured to completion, not just launched.
        timings[b] = @belapsed begin
            y, _ = $(model)($(x), $(ps), $(Lux.testmode(st)))
            CUDA.synchronize()
        end

        println("Best timing: $(timings[b]) s")
    end

    println(timings)
end
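
A note on the `CUDA.synchronize()` inside the `@belapsed` block: CUDA kernel launches return asynchronously, so timing without a synchronize can measure little more than launch overhead. A minimal standalone illustration of the difference (not part of this PR; assumes a CUDA-capable GPU):

```julia
using BenchmarkTools, LuxCUDA

x = CUDA.rand(Float32, 4096, 4096)

t_launch = @belapsed $x * $x    # may time only the asynchronous kernel launch
t_full = @belapsed begin        # times the computation through to completion
    $x * $x
    CUDA.synchronize()
end

println("launch-only: $t_launch s, synchronized: $t_full s")
```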
180 changes: 180 additions & 0 deletions perf/resnet/main.py
@@ -0,0 +1,180 @@
import argparse
import time
from functools import partial
from typing import Any
from collections.abc import Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np

ModuleDef = Any


class ResNetBlock(nn.Module):
    """ResNet block."""

    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        residual = x
        y = self.conv(self.filters, (3, 3), self.strides)(x)
        y = self.norm()(y)
        y = self.act(y)
        y = self.conv(self.filters, (3, 3))(y)
        # Zero-init the final norm's scale so each block starts as identity.
        y = self.norm(scale_init=nn.initializers.zeros_init())(y)

        if residual.shape != y.shape:
            residual = self.conv(self.filters, (1, 1), self.strides, name="conv_proj")(
                residual
            )
            residual = self.norm(name="norm_proj")(residual)

        return self.act(residual + y)


class BottleneckResNetBlock(nn.Module):
    """Bottleneck ResNet block."""

    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        residual = x
        y = self.conv(self.filters, (1, 1))(x)
        y = self.norm()(y)
        y = self.act(y)
        y = self.conv(self.filters, (3, 3), self.strides)(y)
        y = self.norm()(y)
        y = self.act(y)
        y = self.conv(self.filters * 4, (1, 1))(y)
        y = self.norm(scale_init=nn.initializers.zeros_init())(y)

        if residual.shape != y.shape:
            residual = self.conv(
                self.filters * 4, (1, 1), self.strides, name="conv_proj"
            )(residual)
            residual = self.norm(name="norm_proj")(residual)

        return self.act(residual + y)


class ResNet(nn.Module):
    """ResNetV1.5."""

    stage_sizes: Sequence[int]
    block_cls: ModuleDef
    num_classes: int
    num_filters: int = 64
    dtype: Any = jnp.float32
    act: Callable = nn.relu
    conv: ModuleDef = nn.Conv

    @nn.compact
    def __call__(self, x, train: bool = True):
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        norm = partial(
            nn.BatchNorm,
            use_running_average=not train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=self.dtype,
            axis_name="batch",
        )

        x = conv(
            self.num_filters,
            (7, 7),
            (2, 2),
            padding=[(3, 3), (3, 3)],
            name="conv_init",
        )(x)
        x = norm(name="bn_init")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(
                    self.num_filters * 2**i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    act=self.act,
                )(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        return x


ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)

ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)

ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock)

ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock)

ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock)

ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Accept one or more batch sizes, e.g. `--batch-size 1 4 32 128`.
    parser.add_argument("--batch-size", type=int, nargs="+", default=[1, 4, 32, 128])
    parser.add_argument("--model-size", type=int, default=50)
    args = parser.parse_args()

    if args.model_size == 18:
        model = ResNet18
    elif args.model_size == 34:
        model = ResNet34
    elif args.model_size == 50:
        model = ResNet50
    elif args.model_size == 101:
        model = ResNet101
    elif args.model_size == 152:
        model = ResNet152
    elif args.model_size == 200:
        model = ResNet200
    else:
        parser.error(f"unsupported --model-size: {args.model_size}")

    model = model(num_classes=1000)

    timings = dict()

    for b in args.batch_size:
        print(f"batch_size={b}")

        x = jnp.ones((b, 224, 224, 3), jnp.float32)
        params = model.init(random.PRNGKey(0), x, train=False)
        param_count = sum(p.size for p in jax.tree.leaves(params))

        print(f"Param count: {param_count}")

        # Lower and compile ahead of time so the loop below measures
        # steady-state execution rather than JIT compilation.
        apply_fn_compiled = (
            jax.jit(partial(model.apply, train=False)).lower(params, x).compile()
        )

        best_timing = np.inf
        for _ in range(100):
            t1 = time.time()
            apply_fn_compiled(params, x).block_until_ready()
            t2 = time.time()
            best_timing = min(best_timing, t2 - t1)

        timings[b] = best_timing
        print(f"Best timing: {best_timing:.5f} s")

    print(timings)
37 changes: 37 additions & 0 deletions perf/resnet/reactant.jl
@@ -0,0 +1,37 @@
using Comonicon, BenchmarkTools
using Lux, Enzyme, Reactant, Random

Reactant.set_default_backend("gpu")

include("resnet.jl")

Comonicon.@main function main(;
    optimize::String="all", batch_size::Vector{Int}=[1, 4, 32, 128],
    model_size::Int=50
)
    dev = reactant_device(; force=true)

    model = ResNet(model_size)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    println("Param count: $(Lux.parameterlength(ps))")
    println("State count: $(Lux.statelength(st))")

    timings = Dict{Int, Float64}()

    for b in batch_size
        println("batch_size=$b")

        x = rand(Float32, 224, 224, 3, b) |> dev

        # Compile once per batch size; `sync=true` makes the compiled call
        # block until the device finishes, so the timing below measures
        # completed execution rather than kernel launch.
        model_compiled = Reactant.compile(
            model, (x, ps, Lux.testmode(st)); sync=true, optimize=Symbol(optimize)
        )

        timings[b] = @belapsed $(model_compiled)($(x), $(ps), $(Lux.testmode(st)))

        println("Best timing: $(timings[b]) s")
    end

    println(timings)
end
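
For reference, the compile-once-then-call pattern this script relies on, reduced to a toy standalone example (hypothetical `Dense` model, not part of this PR; assumes a Reactant-supported device):

```julia
using Lux, Reactant, Random

dev = reactant_device()

model = Dense(2 => 4, relu)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 8) |> dev

# Compile the forward pass once for these argument types, then reuse the
# compiled callable; subsequent calls skip tracing and XLA compilation.
compiled = Reactant.compile(model, (x, ps, Lux.testmode(st)))
y, _ = compiled(x, ps, Lux.testmode(st))
```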