
Commit

perf: add comparisons to JAX
avik-pal committed Nov 2, 2024
1 parent 989ac15 commit 1dc9176
Showing 6 changed files with 287 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -43,3 +43,11 @@ benchmarks/results

# Generated by tutorials
pinn_nested_ad.gif

# poetry
poetry.lock
*.egg-info
*.pyc
.venv
.python-version
__pycache__
9 changes: 9 additions & 0 deletions perf/Project.toml
@@ -0,0 +1,9 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
14 changes: 14 additions & 0 deletions perf/README.md
@@ -0,0 +1,14 @@
# Performance Benchmarks

## ResNet

### ResNet50 (Forward Pass)

Benchmarks were run on a single NVIDIA RTX 4050 GPU with 6 GB of memory.

| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) |
| ---------- | ------------------ | ---------------------------- |
| 1 | 0.00403 s | 0.00057587 s |
| 4 | 0.00879 s | 0.000712372 s |
| 32 | 0.05146 s | 0.000810471 s |
| 128 | 0.20071 s | 0.035191948 s |
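
Timings are best-case latencies. The Flax number is the minimum over 25 calls of a jitted forward pass (the first call's compilation time is discarded by taking the minimum), and the Lux + Reactant number is the `@belapsed` minimum for the `@compile`d model. A minimal sketch of the Flax-side measurement loop, with a hypothetical stand-in `forward` function in place of the full ResNet:

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the jitted ResNet forward pass in perf/resnet/main.py
forward = jax.jit(lambda params, x: jnp.tanh(x @ params))

params = jnp.ones((224, 16), jnp.float32)
x = jnp.ones((128, 224), jnp.float32)

best = float("inf")
for _ in range(25):
    t0 = time.time()
    # block_until_ready() waits for the device, so the wall time covers the
    # computation itself rather than just the asynchronous dispatch
    forward(params, x).block_until_ready()
    best = min(best, time.time() - t0)
print(f"Best timing: {best:.5f} s")
```

The Julia side measures the same quantity with BenchmarkTools' `@belapsed`, synchronizing on the result so Reactant's asynchronous execution is included.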
19 changes: 19 additions & 0 deletions perf/pyproject.toml
@@ -0,0 +1,19 @@
[tool.poetry]
description = "Benchmarking Lux against Python frameworks"
authors = ["Avik Pal <[email protected]>"]
readme = "../README.md"
package-mode = false

[tool.poetry.dependencies]
python = "~3.10"
jax = {extras = ["cuda12"], version = "^0.4.35"}
flax = "~0.10.1"
einops = "^0.8.0"


[tool.poetry.group.dev.dependencies]
ipython = "^8.29.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
59 changes: 59 additions & 0 deletions perf/resnet/main.jl
@@ -0,0 +1,59 @@
using ArgParse, BenchmarkTools
import Metalhead
using Lux, Enzyme, Reactant, Random, Boltz

Reactant.set_default_backend("gpu")

function parse_commandline()
s = ArgParseSettings()

#! format: off
    @add_arg_table! s begin
        "--batch-size"
            help = "Batch sizes to benchmark"
            arg_type = Int
            nargs = '+'
            default = [1, 4, 32, 128]

        "--model-size"
            help = "ResNet depth (18, 34, 50, 101, or 152)"
            arg_type = Int
            default = 50
    end
#! format: on

return parse_args(s)
end

function main()
parsed_args = parse_commandline()
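    # XLA (GPU) device to which parameters, states, and inputs are moved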
dev = xla_device()

model = Vision.ResNet(parsed_args["model-size"])
ps, st = Lux.setup(Random.default_rng(), model)
ps_ra = ps |> dev
    st_ra = Lux.testmode(st) |> dev  # inference mode: BatchNorm uses running statistics

println("Param count: $(Lux.parameterlength(ps_ra))")
println("State count: $(Lux.statelength(st_ra))")

timings = Dict{Int, Float64}()

for b in parsed_args["batch-size"]
println("batch_size=$b")

        x = rand(Float32, 224, 224, 3, b) |> dev  # Lux convolutions expect WHCN layout

        # Trace and compile the forward pass ahead of time with Reactant
        model_compiled = @compile model(x, ps_ra, st_ra)

        # Minimum wall time over repeated runs; synchronize the result so the
        # asynchronous execution is included in the measurement
        timings[b] = @belapsed begin
            y, _ = $(model_compiled)($(x), $(ps_ra), $(st_ra))
            Reactant.synchronize(y)
        end

println("Best timing: $(timings[b]) s")
end

println(timings)
end

main()
178 changes: 178 additions & 0 deletions perf/resnet/main.py
@@ -0,0 +1,178 @@
import argparse
import time
from functools import partial
from typing import Any
from collections.abc import Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np

ModuleDef = Any


class ResNetBlock(nn.Module):
"""ResNet block."""

filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: tuple[int, int] = (1, 1)

@nn.compact
def __call__(self, x):
residual = x
y = self.conv(self.filters, (3, 3), self.strides)(x)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters, (3, 3))(y)
y = self.norm(scale_init=nn.initializers.zeros_init())(y)

if residual.shape != y.shape:
residual = self.conv(self.filters, (1, 1), self.strides, name="conv_proj")(
residual
)
residual = self.norm(name="norm_proj")(residual)

return self.act(residual + y)


class BottleneckResNetBlock(nn.Module):
"""Bottleneck ResNet block."""

filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: tuple[int, int] = (1, 1)

@nn.compact
def __call__(self, x):
residual = x
y = self.conv(self.filters, (1, 1))(x)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters, (3, 3), self.strides)(y)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters * 4, (1, 1))(y)
y = self.norm(scale_init=nn.initializers.zeros_init())(y)

if residual.shape != y.shape:
residual = self.conv(
self.filters * 4, (1, 1), self.strides, name="conv_proj"
)(residual)
residual = self.norm(name="norm_proj")(residual)

return self.act(residual + y)


class ResNet(nn.Module):
"""ResNetV1.5."""

stage_sizes: Sequence[int]
block_cls: ModuleDef
num_classes: int
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
conv: ModuleDef = nn.Conv

@nn.compact
def __call__(self, x, train: bool = True):
conv = partial(self.conv, use_bias=False, dtype=self.dtype)
norm = partial(
nn.BatchNorm,
use_running_average=not train,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype,
axis_name="batch",
)

x = conv(
self.num_filters,
(7, 7),
(2, 2),
padding=[(3, 3), (3, 3)],
name="conv_init",
)(x)
x = norm(name="bn_init")(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
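        # Stages after the first downsample in their first block via strides=(2, 2)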
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
self.num_filters * 2**i,
strides=strides,
conv=conv,
norm=norm,
act=self.act,
)(x)
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
x = jnp.asarray(x, self.dtype)
return x


ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)

ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)

ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock)

ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock)

ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock)

ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock)

if __name__ == "__main__":
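    # Same CLI as the Julia script: a list of batch sizes and the ResNet depth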
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=list, default=[1, 4, 32, 128])
parser.add_argument("--model-size", type=int, default=50)
args = parser.parse_args()

if args.model_size == 18:
model = ResNet18
elif args.model_size == 34:
model = ResNet34
elif args.model_size == 50:
model = ResNet50
elif args.model_size == 101:
model = ResNet101
elif args.model_size == 152:
model = ResNet152
    elif args.model_size == 200:
        model = ResNet200
    else:
        raise ValueError(f"Unsupported model size: {args.model_size}")

model = model(num_classes=1000)

timings = dict()

for b in args.batch_size:
print(f"batch_size={b}")

        x = jnp.ones((b, 224, 224, 3), jnp.float32)  # Flax convolutions expect NHWC
        params = model.init(random.PRNGKey(0), x, train=False)
        param_count = sum(p.size for p in jax.tree.leaves(params))

print(f"Param count: {param_count}")

        # JIT-compile inference; XLA compilation happens on the first call below
        apply_fn_compiled = jax.jit(partial(model.apply, train=False))

        best_timing = np.inf
        for _ in range(25):
            t1 = time.time()
            # block_until_ready() waits for the device, so the measurement covers
            # the computation itself, not just the asynchronous dispatch; the
            # first iteration includes compile time, which the min() discards
            apply_fn_compiled(params, x).block_until_ready()
            t2 = time.time()
            best_timing = min(best_timing, t2 - t1)

timings[b] = best_timing
print(f"Best timing: {best_timing:.5f} s")

print(timings)
