Commit f45a1ab

Update a few examples to use compile (#420)

* update a few examples to use compile
* update mnist
* add compile to vae and rename some stuff for simplicity
* update reqs
* use state in eval
* GCN example with RNG + dropout
* add a bit of prefetching
1 parent da7adae commit f45a1ab

17 files changed, +164 −118 lines
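For context, every example touched here follows the same pattern: the gradient computation and optimizer update move into a function compiled with `mx.compile`, and the model and optimizer state are passed as implicit `inputs`/`outputs` so in-place updates are tracked across calls. A minimal sketch of that pattern on a throwaway model (the layer sizes, learning rate, and data below are illustrative, not taken from any of the examples):

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy model and optimizer; sizes and learning rate are illustrative only.
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
mx.eval(model.parameters())
optimizer = optim.SGD(learning_rate=0.1)


def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")


# Parameters and optimizer state are mutated inside `step`, so they are
# declared as implicit inputs and outputs of the compiled function.
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(X, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    return loss


X = mx.random.normal((8, 4))
y = mx.array([0, 1, 0, 1, 0, 1, 0, 1])
loss = step(X, y)
mx.eval(state)  # materialize the updated parameters and optimizer state
print(loss.item())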

cifar/dataset.py (+4, −5)

@@ -1,14 +1,12 @@
-import math
-
-import mlx.core as mx
+import numpy as np
 from mlx.data.datasets import load_cifar10


 def get_cifar10(batch_size, root=None):
     tr = load_cifar10(root=root)

-    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
-    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
+    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
+    std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

     def normalize(x):
         x = x.astype("float32") / 255.0
@@ -23,6 +21,7 @@ def normalize(x):
         .image_random_crop("image", 32, 32)
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     test = load_cifar10(root=root, train=False)
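The `.prefetch(4, 4)` call keeps data loading off the critical path of the now-faster compiled step. As I read the mlx-data stream API, the first argument is how many batches to keep buffered and the second the number of worker threads; treat that reading as an assumption, it is not spelled out in this diff. A hedged sketch of a consumer of such a pipeline:

from mlx.data.datasets import load_cifar10

# Buffer batches ahead of the training loop so augmentation and batching do
# not stall the compiled training step. The (4, 4) arguments are read here as
# "4 batches buffered, 4 worker threads" -- an assumption, not from the diff.
tr = load_cifar10()
tr_iter = tr.shuffle().to_stream().batch(32).prefetch(4, 4)

for batch in tr_iter:
    images = batch["image"]  # already materialized by the prefetch workers
    break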

cifar/main.py (+12, −5)

@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -33,19 +34,25 @@ def train_step(model, inp, tgt):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
-
     losses = []
     accs = []
     samples_per_sec = []

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inp, tgt):
+        train_step_fn = nn.value_and_grad(model, train_step)
+        (loss, acc), grads = train_step_fn(model, inp, tgt)
+        optimizer.update(model, grads)
+        return loss, acc
+
     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()

cifar/requirements.txt (+2, −2)

@@ -1,3 +1,3 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
-numpy
+numpy

cvae/dataset.py (+1)

@@ -23,6 +23,7 @@ def normalize(x):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     # iterator over test set

cvae/main.py (+52, −52)

@@ -2,14 +2,15 @@

 import argparse
 import time
+from functools import partial
 from pathlib import Path

 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image

@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div


-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)

     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())

-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

     optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,20 +102,54 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()

         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
+        loss_acc = 0.0
+        throughput_acc = 0.0
+
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
         toc = time.perf_counter()

-        vae.eval()
-
         print(
             " | ".join(
                 [
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )
+
+        model.eval()
+
         # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")

-        vae.save_weights(str(save_dir / "weights.npz"))
+        model.save_weights(str(save_dir / "weights.npz"))


 if __name__ == "__main__":

cvae/requirements.txt (+1, −1)

@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow

cvae/model.py → cvae/vae.py

File renamed without changes.
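The rename goes hand in hand with the changes in cvae/main.py above: the module is now imported as `vae`, which frees the name `model` for the network instance that the compiled `step` closes over. A tiny sketch of the resulting usage (the constructor arguments are illustrative placeholders, not values from the example):

import mlx.core as mx
import vae  # cvae/vae.py, formerly cvae/model.py -- the class itself is unchanged

# Illustrative arguments; in main.py these come from argparse and the dataset.
model = vae.CVAE(8, (64, 64), 64)
mx.eval(model.parameters())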

gcn/main.py (+18, −7)

@@ -1,4 +1,6 @@
+import time
 from argparse import ArgumentParser
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -47,23 +49,31 @@ def main(args):
     mx.eval(gcn.parameters())

     optimizer = optim.Adam(learning_rate=args.lr)
-    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

-    best_val_loss = float("inf")
-    cnt = 0
+    state = [gcn.state, optimizer.state, mx.random.state]

-    # Training loop
-    for epoch in range(args.epochs):
-        # Loss
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step():
+        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
         (loss, y_hat), grads = loss_and_grad_fn(
             gcn, x, adj, y, train_mask, args.weight_decay
         )
         optimizer.update(gcn, grads)
-        mx.eval(gcn.parameters(), optimizer.state)
+        return loss, y_hat
+
+    best_val_loss = float("inf")
+    cnt = 0
+
+    # Training loop
+    for epoch in range(args.epochs):
+        tic = time.time()
+        loss, y_hat = step()
+        mx.eval(state)

         # Validation
         val_loss = loss_fn(y_hat[val_mask], y[val_mask])
         val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+        toc = time.time()

         # Early stopping
         if val_loss < best_val_loss:
@@ -81,6 +91,7 @@ def main(args):
                     f"Train loss: {loss.item():.3f}",
                     f"Val loss: {val_loss.item():.3f}",
                     f"Val acc: {val_acc.item():.2f}",
+                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                 ]
             )
         )
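The GCN example is the one in this commit whose compiled step also captures `mx.random.state`: dropout draws random numbers inside the compiled function, so the global RNG state has to be threaded through as an implicit input and output, otherwise every call would replay the same dropout mask. A minimal sketch of that interaction, using a stand-in module rather than the example's GCN:

from functools import partial

import mlx.core as mx
import mlx.nn as nn

# Stand-in module with dropout; not the example's GCN.
net = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
mx.eval(net.parameters())

# Dropout draws random numbers inside the compiled function, so the global
# RNG state must be an implicit input/output alongside the parameters.
state = [net.state, mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def forward(x):
    return net(x)


x = mx.ones((2, 16))
a = forward(x)
b = forward(x)
mx.eval(a, b)  # a and b differ because the RNG state advanced between calls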

llms/mlx_lm/requirements.txt (+1, −1)

@@ -1,4 +1,4 @@
-mlx
+mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf

mnist/main.py (+15, −9)

@@ -2,6 +2,7 @@

 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -34,10 +35,6 @@ def loss_fn(model, X, y):
     return nn.losses.cross_entropy(model(X), y, reduction="mean")


-def eval_fn(model, X, y):
-    return mx.mean(mx.argmax(model(X), axis=1) == y)
-
-
 def batch_iterate(batch_size, X, y):
     perm = mx.array(np.random.permutation(y.size))
     for s in range(0, y.size, batch_size):
@@ -65,16 +62,25 @@ def main(args):
     model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
     mx.eval(model.parameters())

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.SGD(learning_rate=learning_rate)
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+    @partial(mx.compile, inputs=model.state, outputs=model.state)
+    def step(X, y):
+        loss, grads = loss_and_grad_fn(model, X, y)
+        optimizer.update(model, grads)
+        return loss
+
+    @partial(mx.compile, inputs=model.state)
+    def eval_fn(X, y):
+        return mx.mean(mx.argmax(model(X), axis=1) == y)

     for e in range(num_epochs):
         tic = time.perf_counter()
         for X, y in batch_iterate(batch_size, train_images, train_labels):
-            loss, grads = loss_and_grad_fn(model, X, y)
-            optimizer.update(model, grads)
-            mx.eval(model.parameters(), optimizer.state)
-        accuracy = eval_fn(model, test_images, test_labels)
+            step(X, y)
+            mx.eval(model.state)
+        accuracy = eval_fn(test_images, test_labels)
         toc = time.perf_counter()
         print(
             f"Epoch {e}: Test accuracy {accuracy.item():.3f},"

mnist/requirements.txt (+2, −2)

@@ -1,2 +1,2 @@
-mlx
-numpy
+mlx>=0.2
+numpy

normalizing_flow/main.py (+15, −8)

@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import matplotlib.pyplot as plt
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,18 +29,23 @@ def main(args):
     def loss_fn(model, x):
         return -mx.mean(model(x))

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.Adam(learning_rate=args.learning_rate)

-    with trange(args.n_steps) as steps:
-        for step in steps:
-            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
-            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
+    state = [model.state, optimizer.state]

-            optimizer.update(model, grads)
-            mx.eval(model.parameters())
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, x)
+        optimizer.update(model, grads)
+        return loss

-            steps.set_postfix(val=loss)
+    with trange(args.n_steps) as steps:
+        for it in steps:
+            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
+            loss = step(mx.array(x[idx]))
+            mx.eval(state)
+            steps.set_postfix(val=loss.item())

     # Plot samples from trained flow
