Skip to content

Commit

Permalink
perf: try AoT compile for jax
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 5, 2024
1 parent 0537f57 commit 8d2209e
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion perf/resnet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def __call__(self, x, train: bool = True):

print(f"Param count: {param_count}")

apply_fn_compiled = jax.jit(partial(model.apply, train=False))
apply_fn_compiled = (
jax.jit(partial(model.apply, train=False)).lower(params, x).compile()
)

best_timing = np.inf
for i in range(100):
Expand Down

0 comments on commit 8d2209e

Please sign in to comment.