# config.py
import torch.nn as nn
from utils import get_device
### Data
IMG_SIZE = 256
# Per-channel normalization statistics for the two image domains; a mean and
# std of 0.5 map pixel values from [0, 1] to [-1, 1].
X_MEAN = (0.5, 0.5, 0.5)
X_STD = (0.5, 0.5, 0.5)
Y_MEAN = (0.5, 0.5, 0.5)
Y_STD = (0.5, 0.5, 0.5)
FIXED_PAIRS = False
SCALE = (0.8, 1)  # Scale (area) range for random resized cropping.
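# Illustrative sketch of how the constants above could feed a torchvision
# pipeline: RandomResizedCrop consumes IMG_SIZE and SCALE, Normalize consumes
# the per-domain mean/std. The composition (and the torchvision dependency)
# is an assumption, not necessarily what this repo's dataset code does;
# torchvision is imported lazily so this module stays importable without it.
def build_x_transform():
    import torchvision.transforms as T
    return T.Compose([
        T.RandomResizedCrop(IMG_SIZE, scale=SCALE),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=X_MEAN, std=X_STD),  # maps pixels from [0, 1] to [-1, 1]
    ])
# A matching Y-domain transform would use Y_MEAN / Y_STD instead.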
### Objective
# "For $\mathcal{L}_{GAN}$, we replace the negative log likelihood objective by a least-squares loss. This loss
# is more stable during training and generates higher quality results."
GAN_CRIT = nn.MSELoss()
CYCLE_CRIT = nn.L1Loss()
ID_CRIT = nn.L1Loss()
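# Minimal sketch of how the least-squares criterion GAN_CRIT is typically
# applied: discriminator outputs are regressed toward 1 for real images and 0
# for fakes. These helper functions are illustrative assumptions; only the
# criterion objects above are defined by this config. The discriminator loss
# is halved, following the CycleGAN training details, to slow D relative to G.
def lsgan_disc_loss(disc_real_out, disc_fake_out):
    import torch
    real_loss = GAN_CRIT(disc_real_out, torch.ones_like(disc_real_out))
    fake_loss = GAN_CRIT(disc_fake_out, torch.zeros_like(disc_fake_out))
    return 0.5 * (real_loss + fake_loss)

def lsgan_gen_loss(disc_fake_out):
    import torch
    # The generator is rewarded when the discriminator scores its fakes as real.
    return GAN_CRIT(disc_fake_out, torch.ones_like(disc_fake_out))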
### Training
SEED = 124  # Random seed for reproducibility.
DEVICE = get_device()
TRAIN_BATCH_SIZE = 1 # "We use the Adam solver with a batch size of 1."
LR = 0.0002 # "We train our networks from scratch, with a learning rate of 0.0002."
BETA1 = 0.5
BETA2 = 0.999
CYCLE_LAMB = 10 # "We set $\lambda = 10$."
# "The weight for the identity mapping loss was $0.5\lambda$ where $\lambda$ was the weight for cycle consistency
# loss."
ID_LAMB = 0.5 * CYCLE_LAMB
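# Sketch of how the criteria and weights above typically combine into the full
# generator objective (adversarial + cycle-consistency + identity terms).
# Argument names are illustrative assumptions: cycled_x = F(G(x)),
# cycled_y = G(F(y)), id_x = F(x), id_y = G(y). The actual training loop may
# organize this differently.
def generator_objective(adv_loss, real_x, real_y, cycled_x, cycled_y, id_x, id_y):
    cycle_loss = CYCLE_CRIT(cycled_x, real_x) + CYCLE_CRIT(cycled_y, real_y)
    id_loss = ID_CRIT(id_x, real_x) + ID_CRIT(id_y, real_y)
    return adv_loss + CYCLE_LAMB * cycle_loss + ID_LAMB * id_loss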
# "To reduce model oscillation we update the discriminators using a history of generated images rather than the
# ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images."
BUFFER_SIZE = 50
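# Minimal sketch of the image history buffer the quoted comment describes:
# once the buffer is full, a newly generated image is, with probability 0.5,
# swapped for one stored earlier, so the discriminator also sees older fakes.
# The class the repo actually uses is not shown here; this is an assumption
# about its shape, written for a batch size of 1.
import random

class ImageBuffer:
    def __init__(self, size=BUFFER_SIZE):
        self.size = size
        self.images = []

    def query(self, image):
        # Fill the buffer first; afterwards, half the time return (and replace)
        # a randomly chosen stored image instead of the newest one.
        if len(self.images) < self.size:
            self.images.append(image)
            return image
        if random.random() < 0.5:
            idx = random.randrange(self.size)
            old, self.images[idx] = self.images[idx], image
            return old
        return image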
# "We keep the same learning rate for the first 100 epochs and linearly decay the rate
# to zero over the next 100 epochs."
N_EPOCHS_BEFORE_DECAY = 100
N_EPOCHS = 200
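# Sketch of the schedule the quoted comment describes: a constant learning
# rate for the first N_EPOCHS_BEFORE_DECAY epochs, then a linear decay to zero
# over the remaining epochs. Wiring it up via torch.optim.lr_scheduler.LambdaLR
# is an assumption about this repo; the decay rule itself follows the paper.
def linear_decay_factor(epoch):
    # Multiplicative LR factor: 1.0 for epochs [0, N_EPOCHS_BEFORE_DECAY),
    # then linearly down to 0 by epoch N_EPOCHS.
    decay_epochs = N_EPOCHS - N_EPOCHS_BEFORE_DECAY
    return 1.0 - max(0, epoch - N_EPOCHS_BEFORE_DECAY) / decay_epochs

# Illustrative usage (`model` is hypothetical; call sched.step() once per epoch):
# import torch
# optim = torch.optim.Adam(model.parameters(), lr=LR, betas=(BETA1, BETA2))
# sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=linear_decay_factor)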
SAVE_GENS_EVERY = 10  # Interval (presumably in epochs) between generator checkpoint saves.
GEN_SAMPLES_EVERY = 4  # Interval (presumably in epochs) between dumping sample translations.