@@ -2,14 +2,15 @@

import argparse
import time
+from functools import partial
from pathlib import Path

import dataset
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
-import model
import numpy as np
+import vae
from mlx.utils import tree_flatten
from PIL import Image

@@ -53,44 +54,6 @@ def loss_fn(model, X):
    return recon_loss + kl_div


-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
def reconstruct(model, batch, out_file):
    # Reconstruct a single batch only
    images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())

-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
    print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

    optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,20 +102,54 @@ def main(args):
    train_batch = next(train_iter)
    test_batch = next(test_iter)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
    for e in range(1, args.epochs + 1):
        # Reset iterators and stats at the beginning of each epoch
        train_iter.reset()
-        vae.train()
+        model.train()

        # Train one epoch
        tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
+        loss_acc = 0.0
+        throughput_acc = 0.0
+
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
        toc = time.perf_counter()

-        vae.eval()
-
        print(
            " | ".join(
                [
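The core of this change: the forward/backward/update sequence now runs as a single function compiled with mx.compile. Because step also mutates model parameters and optimizer state in place, that state is declared via inputs=state and outputs=state so the compiled graph tracks the implicit reads and writes; mx.eval(state) then materializes the lazy updates each iteration, replacing the old mx.eval(model.parameters(), optimizer.state). A minimal, self-contained sketch of the same pattern; the toy linear model, loss, and random data are illustrative placeholders, not part of this commit:

# Sketch of the compiled-step pattern; only the mx.compile/state
# mechanics mirror the commit, everything else is a placeholder.
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.AdamW(learning_rate=1e-3)
mx.eval(model.parameters())

def loss_fn(model, X, y):
    return nn.losses.mse_loss(model(X), y)

# The compiled function reads and mutates model/optimizer state
# implicitly, so it must be declared as both inputs and outputs.
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(X, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
    optimizer.update(model, grads)
    return loss

X = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
for _ in range(3):
    loss = step(X, y)
    mx.eval(state)  # force the lazy parameter and optimizer updates
    print(loss.item())

Note that mx.eval is called with the same state list afterwards, so the new parameters and optimizer state are evaluated together once per iteration.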
@@ -163,14 +160,17 @@ def main(args):
                ]
            )
        )
+
+        model.eval()
+
        # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

        # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")

-        vae.save_weights(str(save_dir / "weights.npz"))
+        model.save_weights(str(save_dir / "weights.npz"))


if __name__ == "__main__":
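For reference, a typical invocation of the script; the flag spellings below are assumptions inferred from the args attributes used above (args.epochs, args.latent_dims, args.max_filters, args.lr), so check the script's argparse definitions:

python main.py --epochs 50 --latent-dims 8 --max-filters 64 --lr 1e-3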