Skip to content

Commit

Permalink
save changes to github ,strengthen visuality and debug about test() w…
Browse files Browse the repository at this point in the history
…ithout open model.eval()
  • Loading branch information
RForestLiu committed Mar 9, 2020
1 parent 5bae999 commit 2233167
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 36 deletions.
15 changes: 15 additions & 0 deletions experiments/run_fabric.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

# Run fabric experiment for each individual dataset.
# For each anomalous digit

for i in {1024,2048}
do
for m in {3..5}
do
echo "Running Fabric ###############"
echo "Manual Seed: $m ###############"
python train.py --dataset fabric --isize 128 --nc 3 --niter 100 --batchsize 32 --nz $i --manualseed $m --display --strengthen --lr 0.00005
done
done
exit 0
14 changes: 14 additions & 0 deletions experiments/run_mnist_s.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash

# Run MNIST experiment for each individual dataset.
# For each anomalous digit
for i in {0..9}
do
for m in {0..2}
do
echo "#Manual Seed: $m"
echo "#Running MNIST2, Abnormal Digit: $i"
python train.py --dataset mnist --isize 32 --nc 1 --niter 15 --abnormal_class $i --manualseed $m --proportion 0.2 --lr 0.002 --display --strengthen --beta1 0.5
done
done
exit 0
15 changes: 15 additions & 0 deletions lib/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,18 @@ def l2_loss(input, target, size_average=True):
return torch.mean(torch.pow((input-target), 2))
else:
return torch.pow((input-target), 2)

def l3_loss(input, target, size_average=True):
""" L3 Loss without reduce flag.
Args:
input (FloatTensor): Input tensor
target (FloatTensor): Output tensor
Returns:
[FloatTensor]: L3 distance between input and output
"""
if size_average:
return torch.mean(torch.pow((input-target), 3))
else:
return torch.pow((input-target), 3)
117 changes: 88 additions & 29 deletions lib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

from lib.networks import NetG, NetD, weights_init
from lib.visualizer import Visualizer
from lib.loss import l2_loss
from lib.loss import l1_loss, l2_loss, l3_loss
from lib.evaluate import evaluate


class BaseModel():
""" Base Model for ganomaly
"""

def __init__(self, opt, dataloader):
##
# Seed for deterministic behavior
Expand All @@ -38,7 +39,7 @@ def __init__(self, opt, dataloader):
self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

##
def set_input(self, input:torch.Tensor):
def set_input(self, input: torch.Tensor):
""" Set input and ground truth
Args:
Expand All @@ -52,6 +53,9 @@ def set_input(self, input:torch.Tensor):
# Copy the first batch as the fixed input.
if self.total_steps == self.opt.batchsize:
self.fixed_input.resize_(input[0].size()).copy_(input[0])
self.visualizer.save_fixed_real_s(self.fixed_input)



##
def seed(self, seed_value):
Expand Down Expand Up @@ -100,8 +104,16 @@ def get_current_images(self):
reals = self.input.data
fakes = self.fake.data
fixed = self.netg(self.fixed_input)[0].data
# point
fixed_reals = self.fixed_input.data
# point
return reals, fakes, fixed, fixed_reals

return reals, fakes, fixed
##point
def get_low_scores_images(self):
"""
"""
return

##
def save_weights(self, epoch):
Expand All @@ -113,10 +125,15 @@ def save_weights(self, epoch):

weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
if not os.path.exists(weight_dir): os.makedirs(weight_dir)

torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
if self.opt.strengthen:
torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
'%s/netG%d.pth' % (weight_dir,self.opt.nz))
torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
'%s/netD%d.pth' % (weight_dir, self.opt.nz))
else:
torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
'%s/netG.pth' % (weight_dir))
torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
'%s/netD.pth' % (weight_dir))

##
Expand All @@ -125,6 +142,8 @@ def train_one_epoch(self):
"""

self.netg.train()
if self.opt.strengthen:
self.netd.train() ## point
epoch_iter = 0
for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])):
self.total_steps += self.opt.batchsize
Expand All @@ -141,12 +160,13 @@ def train_one_epoch(self):
self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

if self.total_steps % self.opt.save_image_freq == 0:
reals, fakes, fixed = self.get_current_images()
# point
reals, fakes, fixed, fixed_reals = self.get_current_images()
self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
if self.opt.display:
self.visualizer.display_current_images(reals, fakes, fixed)
self.visualizer.display_current_images(reals, fakes, fixed, fixed_reals)

print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))
print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch + 1, self.opt.niter))
# self.visualizer.print_current_errors(self.epoch, errors)

##
Expand Down Expand Up @@ -181,7 +201,11 @@ def test(self):
Raises:
IOError: Model weights not found.
"""

if self.opt.strengthen:
self.netg.eval()
with torch.no_grad():

# Load the weights of netg and netd.
if self.opt.load_weights:
path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
Expand All @@ -196,12 +220,21 @@ def test(self):
self.opt.phase = 'test'

# Create big error tensor for the test set.
self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32, device=self.device)
self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long, device=self.device)
self.latent_i = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, device=self.device)
self.latent_o = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, device=self.device)
self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32,
device=self.device)
self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long,
device=self.device)
self.latent_i = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32,
device=self.device)
self.latent_o = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32,
device=self.device)
self.last_feature = torch.zeros(size=(
len(self.dataloader['test'].dataset),
list(self.netd.children())[0][-3].out_channels,
list(self.netd.children())[0][-3].kernel_size[0],
list(self.netd.children())[0][-3].kernel_size[1]
), dtype=torch.float32, device=self.device)

# print(" Testing model %s." % self.name)
self.times = []
self.total_steps = 0
epoch_iter = 0
Expand All @@ -211,14 +244,24 @@ def test(self):
time_i = time.time()
self.set_input(data)
self.fake, latent_i, latent_o = self.netg(self.input)
_, features = self.netd(self.input)

error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
time_o = time.time()

self.an_scores[i*self.opt.batchsize : i*self.opt.batchsize+error.size(0)] = error.reshape(error.size(0))
self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+error.size(0)] = self.gt.reshape(error.size(0))
self.latent_i [i*self.opt.batchsize : i*self.opt.batchsize+error.size(0), :] = latent_i.reshape(error.size(0), self.opt.nz)
self.latent_o [i*self.opt.batchsize : i*self.opt.batchsize+error.size(0), :] = latent_o.reshape(error.size(0), self.opt.nz)
self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0)] = error.reshape(
error.size(0))
self.gt_labels[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0)] = self.gt.reshape(
error.size(0))
self.latent_i[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = latent_i.reshape(
error.size(0), self.opt.nz)
self.latent_o[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = latent_o.reshape(
error.size(0), self.opt.nz)
self.last_feature[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = features.reshape(
error.size(0),
list(self.netd.children())[0][-3].out_channels,
list(self.netd.children())[0][-3].kernel_size[0],
list(self.netd.children())[0][-3].kernel_size[1])

self.times.append(time_o - time_i)

Expand All @@ -227,32 +270,44 @@ def test(self):
dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
if not os.path.isdir(dst):
os.makedirs(dst)
real, fake, _ = self.get_current_images()
vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)
real, fake, _, _ = self.get_current_images() #point add attribute fixed_real
vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1), normalize=True)
vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1), normalize=True)




# Measure inference time.
self.times = np.array(self.times)
self.times = np.mean(self.times[:100] * 1000)

# Scale error vector between [0, 1]
self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (torch.max(self.an_scores) - torch.min(self.an_scores))
self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
torch.max(self.an_scores) - torch.min(self.an_scores))

# auc, eer = roc(self.gt_labels, self.an_scores)
auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

if self.opt.strengthen and self.opt.phase == 'test':
self.visualizer.display_scores_histo(self.epoch, self.an_scores, self.gt_labels)
self.visualizer.display_feature(self.last_feature, self.gt_labels)

if self.opt.display_id > 0 and self.opt.phase == 'test':
counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
self.visualizer.plot_performance(self.epoch, counter_ratio, performance)

return performance


##
class Ganomaly(BaseModel):
"""GANomaly Class
"""

@property
def name(self): return 'Ganomaly'
def name(self):
return 'Ganomaly'

def __init__(self, opt, dataloader):
super(Ganomaly, self).__init__(opt, dataloader)
Expand Down Expand Up @@ -284,11 +339,13 @@ def __init__(self, opt, dataloader):

##
# Initialize input tensors.
self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device)
self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32,
device=self.device)
self.label = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device)
self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long, device=self.device)
self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device)
self.real_label = torch.ones (size=(self.opt.batchsize,), dtype=torch.float32, device=self.device)
self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long, device=self.device)
self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize),
dtype=torch.float32, device=self.device)
self.real_label = torch.ones(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device)
self.fake_label = torch.zeros(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device)
##
# Setup optimizer
Expand Down Expand Up @@ -340,7 +397,7 @@ def reinit_d(self):
""" Re-initialize the weights of netD
"""
self.netd.apply(weights_init)
print(' Reloading net d')
if(self.opt.strengthen != 1): print(' Reloading net d')

def optimize_params(self):
""" Forwardpass, Loss Computation and Backwardpass.
Expand All @@ -360,3 +417,5 @@ def optimize_params(self):
self.backward_d()
self.optimizer_d.step()
if self.err_d.item() < 1e-5: self.reinit_d()


3 changes: 2 additions & 1 deletion lib/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
cngf = cngf // 2
csize = csize * 2


# Extra layers
for t in range(n_extra_layers):
main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf),
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(self, opt):

self.features = nn.Sequential(*layers[:-1])
self.classifier = nn.Sequential(layers[-1])
#self.classifier.add_module('Tanh', nn.Tanh())
self.classifier.add_module('Sigmoid', nn.Sigmoid())

def forward(self, x):
Expand All @@ -156,7 +158,6 @@ def forward(self, x):
classifier = classifier.view(-1, 1).squeeze(1)

return classifier, features

##
class NetG(nn.Module):
"""
Expand Down
Loading

0 comments on commit 2233167

Please sign in to comment.