From 5a390e1f0d3719f999e073336fb6ac2c789591c6 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 1 Nov 2024 21:13:08 -0400
Subject: [PATCH] perf: add comparisons to JAX

---
 .gitignore          |   8 ++
 perf/Project.toml   |   9 +++
 perf/README.md      |  21 +++++
 perf/pyproject.toml |  19 +++++
 perf/resnet/main.jl |  62 ++++++++++++
 perf/resnet/main.py | 182 +++++++++++++++++++++++++++++++++++++++++
 6 files changed, 301 insertions(+)
 create mode 100644 perf/Project.toml
 create mode 100644 perf/README.md
 create mode 100644 perf/pyproject.toml
 create mode 100644 perf/resnet/main.jl
 create mode 100644 perf/resnet/main.py

diff --git a/.gitignore b/.gitignore
index bc5d2c437..ae8f184c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,11 @@ benchmarks/results
 
 # Generated by tutorials
 pinn_nested_ad.gif
+
+# poetry
+poetry.lock
+*.egg-info
+*.pyc
+.venv
+.python-version
+__pycache__
diff --git a/perf/Project.toml b/perf/Project.toml
new file mode 100644
index 000000000..70fbf5734
--- /dev/null
+++ b/perf/Project.toml
@@ -0,0 +1,9 @@
+[deps]
+ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
diff --git a/perf/README.md b/perf/README.md
new file mode 100644
index 000000000..ef3c59945
--- /dev/null
+++ b/perf/README.md
@@ -0,0 +1,21 @@
+# Performance Benchmarks
+
+## ResNet
+
+### ResNet50 (Forward Pass)
+
+Benchmarks were run on a single NVIDIA RTX 4050 GPU with 6 GB of memory.
+
+| Batch Size | Best Timing (Flax) | Best Timing (Lux + Reactant) |
+| ---------- | ------------------ | ---------------------------- |
+| 1          | 0.00403 s          | 0.00057587 s                 |
+| 4          | 0.00788 s          | 0.000712372 s                |
+| 32         | 0.05146 s          | 0.000810471 s                |
+| 128        | 0.20071 s          | 0.009914158 s                |
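+
+### Reproducing
+
+Assuming Julia (with this directory's `Project.toml` instantiated) and Poetry
+are available, run `julia --project=. resnet/main.jl` for the Lux + Reactant
+benchmark and `poetry install && poetry run python resnet/main.py` for the
+Flax side; exact invocations may vary between setups.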
diff --git a/perf/pyproject.toml b/perf/pyproject.toml
new file mode 100644
index 000000000..de592e80e
--- /dev/null
+++ b/perf/pyproject.toml
@@ -0,0 +1,19 @@
+[tool.poetry]
+description = "Benchmarking Lux against Python frameworks"
+authors = ["Avik Pal "]
+readme = "../README.md"
+package-mode = false
+
+[tool.poetry.dependencies]
+python = "~3.10"
+jax = {extras = ["cuda12"], version = "^0.4.35"}
+flax = "~0.10.1"
+einops = "^0.8.0"
+
+
+[tool.poetry.group.dev.dependencies]
+ipython = "^8.29.0"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
diff --git a/perf/resnet/main.jl b/perf/resnet/main.jl
new file mode 100644
index 000000000..47cec54a2
--- /dev/null
+++ b/perf/resnet/main.jl
@@ -0,0 +1,62 @@
+using ArgParse, BenchmarkTools
+import Metalhead  # Metalhead must be loaded for Boltz's Vision models
+using Lux, Enzyme, Reactant, Random, Boltz
+
+Reactant.set_default_backend("gpu")
+
+function parse_commandline()
+    s = ArgParseSettings()
+
+    #! format: off
+    @add_arg_table! s begin
+        "--batch-size"
+            help = "Batch size"
+            nargs = '+'
+            arg_type = Int
+            default = [1, 4, 32, 128]
+
+        "--model-size"
+            help = "Model size"
+            arg_type = Int
+            default = 50
+    end
+    #! format: on
+
+    return parse_args(s)
+end
+
+function main()
+    parsed_args = parse_commandline()
+    dev = xla_device()
+
+    model = Vision.ResNet(parsed_args["model-size"])
+    ps, st = Lux.setup(Random.default_rng(), model)
+    ps_ra = ps |> dev
+    st_ra = Lux.testmode(st) |> dev
+
+    println("Param count: $(Lux.parameterlength(ps_ra))")
+    println("State count: $(Lux.statelength(st_ra))")
+
+    timings = Dict{Int, Float64}()
+
+    for b in parsed_args["batch-size"]
+        println("batch_size=$b")
+
+        x = rand(Float32, 224, 224, 3, b) |> dev
+
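+        # `@compile` traces the call once and lowers it to a single XLA
+        # executable, so `@belapsed` below measures execution, not compilation.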
+        model_compiled = @compile model(x, ps_ra, st_ra)
+
+        timings[b] = @belapsed begin
+            y, _ = $(model_compiled)($(x), $(ps_ra), $(st_ra))
+            Reactant.synchronize(y)
+        end
+
+        println("Best timing: $(timings[b]) s")
+    end
+
+    println(timings)
+end
+
+main()
diff --git a/perf/resnet/main.py b/perf/resnet/main.py
new file mode 100644
index 000000000..c43085c26
--- /dev/null
+++ b/perf/resnet/main.py
@@ -0,0 +1,182 @@
+import argparse
+import time
+from functools import partial
+from typing import Any
+from collections.abc import Callable, Sequence
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import jax.random as random
+import numpy as np
+
+ModuleDef = Any
+
+
+class ResNetBlock(nn.Module):
+    """ResNet block."""
+
+    filters: int
+    conv: ModuleDef
+    norm: ModuleDef
+    act: Callable
+    strides: tuple[int, int] = (1, 1)
+
+    @nn.compact
+    def __call__(self, x):
+        residual = x
+        y = self.conv(self.filters, (3, 3), self.strides)(x)
+        y = self.norm()(y)
+        y = self.act(y)
+        y = self.conv(self.filters, (3, 3))(y)
+        y = self.norm(scale_init=nn.initializers.zeros_init())(y)
+
+        if residual.shape != y.shape:
+            residual = self.conv(self.filters, (1, 1), self.strides, name="conv_proj")(
+                residual
+            )
+            residual = self.norm(name="norm_proj")(residual)
+
+        return self.act(residual + y)
+
+
+class BottleneckResNetBlock(nn.Module):
+    """Bottleneck ResNet block."""
+
+    filters: int
+    conv: ModuleDef
+    norm: ModuleDef
+    act: Callable
+    strides: tuple[int, int] = (1, 1)
+
+    @nn.compact
+    def __call__(self, x):
+        residual = x
+        y = self.conv(self.filters, (1, 1))(x)
+        y = self.norm()(y)
+        y = self.act(y)
+        y = self.conv(self.filters, (3, 3), self.strides)(y)
+        y = self.norm()(y)
+        y = self.act(y)
+        y = self.conv(self.filters * 4, (1, 1))(y)
+        y = self.norm(scale_init=nn.initializers.zeros_init())(y)
+
+        if residual.shape != y.shape:
+            residual = self.conv(
+                self.filters * 4, (1, 1), self.strides, name="conv_proj"
+            )(residual)
+            residual = self.norm(name="norm_proj")(residual)
+
+        return self.act(residual + y)
+
+
+class ResNet(nn.Module):
+    """ResNetV1.5."""
+
+    stage_sizes: Sequence[int]
+    block_cls: ModuleDef
+    num_classes: int
+    num_filters: int = 64
+    dtype: Any = jnp.float32
+    act: Callable = nn.relu
+    conv: ModuleDef = nn.Conv
+
+    @nn.compact
+    def __call__(self, x, train: bool = True):
+        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
+        norm = partial(
+            nn.BatchNorm,
+            use_running_average=not train,
+            momentum=0.9,
+            epsilon=1e-5,
+            dtype=self.dtype,
+            axis_name="batch",
+        )
+
+        x = conv(
+            self.num_filters,
+            (7, 7),
+            (2, 2),
+            padding=[(3, 3), (3, 3)],
+            name="conv_init",
+        )(x)
+        x = norm(name="bn_init")(x)
+        x = nn.relu(x)
+        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
+        for i, block_size in enumerate(self.stage_sizes):
+            for j in range(block_size):
+                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
+                x = self.block_cls(
+                    self.num_filters * 2**i,
+                    strides=strides,
+                    conv=conv,
+                    norm=norm,
+                    act=self.act,
+                )(x)
+        x = jnp.mean(x, axis=(1, 2))
+        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
+        x = jnp.asarray(x, self.dtype)
+        return x
+
+
+ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
+
+ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
+
+ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock)
+
+ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock)
+
+ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock)
+
+ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch-size", type=int, nargs="+", default=[1, 4, 32, 128])
+    parser.add_argument("--model-size", type=int, default=50)
+    args = parser.parse_args()
+
+    if args.model_size == 18:
+        model = ResNet18
+    elif args.model_size == 34:
+        model = ResNet34
+    elif args.model_size == 50:
+        model = ResNet50
+    elif args.model_size == 101:
+        model = ResNet101
+    elif args.model_size == 152:
+        model = ResNet152
+    elif args.model_size == 200:
+        model = ResNet200
+    else:
+        parser.error(f"unsupported model size: {args.model_size}")
+
+    model = model(num_classes=1000)
+
+    timings = dict()
+
+    for b in args.batch_size:
+        print(f"batch_size={b}")
+
+        x = jnp.ones((b, 224, 224, 3), jnp.float32)
+        params = model.init(random.PRNGKey(0), x, train=False)
+        param_count = sum(p.size for p in jax.tree.leaves(params))
+
+        print(f"Param count: {param_count}")
+
+        apply_fn_compiled = jax.jit(partial(model.apply, train=False))
+
+        best_timing = np.inf
+        # The first iteration triggers JIT compilation; taking the minimum
+        # over all runs excludes that one-time cost from the reported timing.
+        for _ in range(100):
+            t1 = time.perf_counter()
+            apply_fn_compiled(params, x).block_until_ready()
+            t2 = time.perf_counter()
+            best_timing = min(best_timing, t2 - t1)
+
+        timings[b] = best_timing
+        print(f"Best timing: {best_timing:.5f} s")
+
+    print(timings)
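
Note on methodology: `main.py` pays its JIT compilation cost inside the first timing iteration and relies on taking the minimum over 100 runs to discard it, whereas `main.jl` compiles explicitly with `@compile` before the timing loop. A closer mirror on the JAX side would compile ahead of time; below is a minimal sketch (not part of the patch), assuming `model`, `params`, and `x` are built exactly as in `perf/resnet/main.py`:

```python
import time
from functools import partial

import jax

# `lower(...).compile()` is JAX's ahead-of-time path: tracing and XLA
# compilation happen here, once, so the loop below times execution only.
compiled = jax.jit(partial(model.apply, train=False)).lower(params, x).compile()

best_timing = float("inf")
for _ in range(100):
    t1 = time.perf_counter()
    compiled(params, x).block_until_ready()
    t2 = time.perf_counter()
    best_timing = min(best_timing, t2 - t1)
print(f"Best timing: {best_timing:.5f} s")
```

Since `min` already discards the compile-heavy first iteration, this is a parity refinement rather than a correction; the numbers in `perf/README.md` should be unaffected.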