-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_vgg16.py
34 lines (28 loc) · 936 Bytes
/
test_vgg16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from __future__ import division,print_function
import os, json
from glob import glob
import numpy as np
np.set_printoptions(precision=4, linewidth=100)
from matplotlib import pyplot as plt
import utils; reload(utils)
from utils import plots
"""Using Theano backend"""
#%matplotlib inline
path = "data/dogsvscats-redux/"
"""
As large as you can, but no larger than 64 is recommended.
If you have an older or cheaper GPU, you'll run out of memory, so will have to decrease this.
"""
batch_size=64
""" Import our class, and instantiate """
import vgg16; reload(vgg16)
from vgg16 import Vgg16
vgg = Vgg16()
"""
Grab a few images at a time for training and validation.
NB: They must be in subdirectories named based on their category
"""
batches = vgg.get_batches(path+'train', batch_size=batch_size)
val_batches = vgg.get_batches(path+'valid', batch_size=batch_size*2)
vgg.finetune(batches)
vgg.fit(batches, val_batches, nb_epoch=1)