forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tpuvm_mnist.yaml
34 lines (29 loc) · 957 Bytes
/
tpuvm_mnist.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# SkyPilot task that runs the Flax MNIST example (see the `run` section).
name: tpuvm_mnist
resources:
  accelerators: tpu-v2-8
  # Arguments that accompany the accelerator request above.
  accelerator_args:
    runtime_version: tpu-vm-base
    # Lowercase is the canonical YAML spelling for booleans.
    tpu_vm: true
# The setup command. Will be run under the working directory.
# Both steps are guarded so that running setup a second time is harmless:
# the clone is skipped when ./flax exists, and the conda env is only built
# when activating it fails.
setup: |
  # git refuses to clone into an existing non-empty directory, so clone once.
  if [ ! -d flax ]; then
    git clone https://github.com/google/flax.git --branch v0.6.11
  fi

  # Test the activation directly instead of inspecting $? on the next line.
  if conda activate flax; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.8 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid
    # package conflicts.
    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip install --upgrade clu
    pip install -e flax
    pip install tensorflow tensorflow-datasets
  fi
# The command to run. Will be run under the working directory.
run: |
  conda activate flax
  # Abort if the example directory is missing rather than launching
  # python3 from the wrong location.
  cd flax/examples/mnist || exit 1
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10