-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathdataset.py
39 lines (31 loc) · 1.05 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
def get_cifar10(batch_size, root=None):
    """Build the CIFAR-10 training and test data streams.

    Each sample's ``"image"`` is cast to float32, divided by 255 and then
    standardized per channel with fixed mean/std constants. When
    ``mx.distributed`` reports more than one process, each split is
    partitioned so this process only sees shard ``group.rank()`` of
    ``group.size()``.

    Args:
        batch_size: Number of samples per batch in both streams.
        root: Optional location passed through unchanged to
            ``load_cifar10``.

    Returns:
        ``(tr_iter, test_iter)``. The training stream is shuffled,
        augmented (random horizontal flip, 4-element zero padding on
        axes 0 and 1, random 32x32 crop), normalized, batched and
        prefetched. The test stream is only normalized and batched.
    """
    tr = load_cifar10(root=root)

    # NOTE(review): these values match the widely used ImageNet per-channel
    # statistics rather than statistics computed on CIFAR-10; left unchanged
    # so results stay comparable with the original pipeline.
    # They are float32 on purpose: float64 constants would make
    # ``(x - mean) / std`` promote every normalized image back to float64,
    # undoing the float32 cast in normalize() and doubling per-sample memory.
    # The (1, 1, 3) shape broadcasts over the last axis, i.e. it assumes
    # channel-last images.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))

    def normalize(x):
        # Scale raw pixel values by 1/255, then standardize each channel.
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    group = mx.distributed.init()
    # Only shard the data when actually running with multiple processes.
    shard = group.size() > 1

    tr_iter = (
        tr.shuffle()
        .partition_if(shard, group.size(), group.rank())
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        # Zero-pad 4 elements on each side of axes 0 and 1 (presumably
        # height and width), then randomly crop back to 32x32.
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    test = load_cifar10(root=root, train=False)
    test_iter = (
        test.to_stream()
        .partition_if(shard, group.size(), group.rank())
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    return tr_iter, test_iter