From 22331676e6a65a567796d3587b6c6110bc12d6d0 Mon Sep 17 00:00:00 2001 From: RForestLiu Date: Mon, 9 Mar 2020 22:15:38 +0800 Subject: [PATCH] save changes to github ,strengthen visuality and debug about test() without open model.eval() --- experiments/run_fabric.sh | 15 +++++ experiments/run_mnist_s.sh | 14 ++++ lib/loss.py | 15 +++++ lib/model.py | 117 ++++++++++++++++++++++++-------- lib/networks.py | 3 +- lib/visualizer.py | 135 ++++++++++++++++++++++++++++++++++++- options.py | 1 + requirements.txt | 10 +-- 8 files changed, 274 insertions(+), 36 deletions(-) create mode 100644 experiments/run_fabric.sh create mode 100644 experiments/run_mnist_s.sh diff --git a/experiments/run_fabric.sh b/experiments/run_fabric.sh new file mode 100644 index 0000000..bdeda0e --- /dev/null +++ b/experiments/run_fabric.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Run fabric experiment for each individual dataset. +# For each anomalous digit + +for i in {1024,2048} +do + for m in {3..5} + do + echo "Running Fabric ###############" + echo "Manual Seed: $m ###############" + python train.py --dataset fabric --isize 128 --nc 3 --niter 100 --batchsize 32 --nz $i --manualseed $m --display --strengthen --lr 0.00005 + done +done +exit 0 diff --git a/experiments/run_mnist_s.sh b/experiments/run_mnist_s.sh new file mode 100644 index 0000000..fdc6371 --- /dev/null +++ b/experiments/run_mnist_s.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Run MNIST experiment for each individual dataset. +# For each anomalous digit +for i in {0..9} +do + for m in {0..2} + do + echo "#Manual Seed: $m" + echo "#Running MNIST2, Abnormal Digit: $i" + python train.py --dataset mnist --isize 32 --nc 1 --niter 15 --abnormal_class $i --manualseed $m --proportion 0.2 --lr 0.002 --display --strengthen --beta1 0.5 + done +done +exit 0 diff --git a/lib/loss.py b/lib/loss.py index 23ff85b..805ecca 100644 --- a/lib/loss.py +++ b/lib/loss.py @@ -37,3 +37,18 @@ def l2_loss(input, target, size_average=True): return torch.mean(torch.pow((input-target), 2)) else: return torch.pow((input-target), 2) + +def l3_loss(input, target, size_average=True): + """ L3 Loss without reduce flag. + + Args: + input (FloatTensor): Input tensor + target (FloatTensor): Output tensor + + Returns: + [FloatTensor]: L3 distance between input and output + """ + if size_average: + return torch.mean(torch.pow((input-target), 3)) + else: + return torch.pow((input-target), 3) \ No newline at end of file diff --git a/lib/model.py b/lib/model.py index d7b2b2e..6c71470 100644 --- a/lib/model.py +++ b/lib/model.py @@ -17,13 +17,14 @@ from lib.networks import NetG, NetD, weights_init from lib.visualizer import Visualizer -from lib.loss import l2_loss +from lib.loss import l1_loss, l2_loss, l3_loss from lib.evaluate import evaluate class BaseModel(): """ Base Model for ganomaly """ + def __init__(self, opt, dataloader): ## # Seed for deterministic behavior @@ -38,7 +39,7 @@ def __init__(self, opt, dataloader): self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu") ## - def set_input(self, input:torch.Tensor): + def set_input(self, input: torch.Tensor): """ Set input and ground truth Args: @@ -52,6 +53,9 @@ def set_input(self, input:torch.Tensor): # Copy the first batch as the fixed input. if self.total_steps == self.opt.batchsize: self.fixed_input.resize_(input[0].size()).copy_(input[0]) + self.visualizer.save_fixed_real_s(self.fixed_input) + + ## def seed(self, seed_value): @@ -100,8 +104,16 @@ def get_current_images(self): reals = self.input.data fakes = self.fake.data fixed = self.netg(self.fixed_input)[0].data + # point + fixed_reals = self.fixed_input.data + # point + return reals, fakes, fixed, fixed_reals - return reals, fakes, fixed + ##point + def get_low_scores_images(self): + """ + """ + return ## def save_weights(self, epoch): @@ -113,10 +125,15 @@ def save_weights(self, epoch): weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights') if not os.path.exists(weight_dir): os.makedirs(weight_dir) - - torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()}, + if self.opt.strengthen: + torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()}, + '%s/netG%d.pth' % (weight_dir,self.opt.nz)) + torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()}, + '%s/netD%d.pth' % (weight_dir, self.opt.nz)) + else: + torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()}, '%s/netG.pth' % (weight_dir)) - torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()}, + torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()}, '%s/netD.pth' % (weight_dir)) ## @@ -125,6 +142,8 @@ def train_one_epoch(self): """ self.netg.train() + if self.opt.strengthen: + self.netd.train() ## point epoch_iter = 0 for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])): self.total_steps += self.opt.batchsize @@ -141,12 +160,13 @@ def train_one_epoch(self): self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors) if self.total_steps % self.opt.save_image_freq == 0: - reals, fakes, fixed = self.get_current_images() + # point + reals, fakes, fixed, fixed_reals = self.get_current_images() self.visualizer.save_current_images(self.epoch, reals, fakes, fixed) if self.opt.display: - self.visualizer.display_current_images(reals, fakes, fixed) + self.visualizer.display_current_images(reals, fakes, fixed, fixed_reals) - print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter)) + print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch + 1, self.opt.niter)) # self.visualizer.print_current_errors(self.epoch, errors) ## @@ -181,7 +201,11 @@ def test(self): Raises: IOError: Model weights not found. """ + + if self.opt.strengthen: + self.netg.eval() with torch.no_grad(): + # Load the weights of netg and netd. if self.opt.load_weights: path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset) @@ -196,12 +220,21 @@ def test(self): self.opt.phase = 'test' # Create big error tensor for the test set. - self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32, device=self.device) - self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long, device=self.device) - self.latent_i = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, device=self.device) - self.latent_o = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, device=self.device) + self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32, + device=self.device) + self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long, + device=self.device) + self.latent_i = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, + device=self.device) + self.latent_o = torch.zeros(size=(len(self.dataloader['test'].dataset), self.opt.nz), dtype=torch.float32, + device=self.device) + self.last_feature = torch.zeros(size=( + len(self.dataloader['test'].dataset), + list(self.netd.children())[0][-3].out_channels, + list(self.netd.children())[0][-3].kernel_size[0], + list(self.netd.children())[0][-3].kernel_size[1] + ), dtype=torch.float32, device=self.device) - # print(" Testing model %s." % self.name) self.times = [] self.total_steps = 0 epoch_iter = 0 @@ -211,14 +244,24 @@ def test(self): time_i = time.time() self.set_input(data) self.fake, latent_i, latent_o = self.netg(self.input) + _, features = self.netd(self.input) - error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1) + error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1) time_o = time.time() - self.an_scores[i*self.opt.batchsize : i*self.opt.batchsize+error.size(0)] = error.reshape(error.size(0)) - self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+error.size(0)] = self.gt.reshape(error.size(0)) - self.latent_i [i*self.opt.batchsize : i*self.opt.batchsize+error.size(0), :] = latent_i.reshape(error.size(0), self.opt.nz) - self.latent_o [i*self.opt.batchsize : i*self.opt.batchsize+error.size(0), :] = latent_o.reshape(error.size(0), self.opt.nz) + self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0)] = error.reshape( + error.size(0)) + self.gt_labels[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0)] = self.gt.reshape( + error.size(0)) + self.latent_i[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = latent_i.reshape( + error.size(0), self.opt.nz) + self.latent_o[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = latent_o.reshape( + error.size(0), self.opt.nz) + self.last_feature[i * self.opt.batchsize: i * self.opt.batchsize + error.size(0), :] = features.reshape( + error.size(0), + list(self.netd.children())[0][-3].out_channels, + list(self.netd.children())[0][-3].kernel_size[0], + list(self.netd.children())[0][-3].kernel_size[1]) self.times.append(time_o - time_i) @@ -227,32 +270,44 @@ def test(self): dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images') if not os.path.isdir(dst): os.makedirs(dst) - real, fake, _ = self.get_current_images() - vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True) - vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True) + real, fake, _, _ = self.get_current_images() #point add attribute fixed_real + vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1), normalize=True) + vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1), normalize=True) + + + # Measure inference time. self.times = np.array(self.times) self.times = np.mean(self.times[:100] * 1000) # Scale error vector between [0, 1] - self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (torch.max(self.an_scores) - torch.min(self.an_scores)) + self.an_scores = (self.an_scores - torch.min(self.an_scores)) / ( + torch.max(self.an_scores) - torch.min(self.an_scores)) + # auc, eer = roc(self.gt_labels, self.an_scores) auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric) performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)]) + if self.opt.strengthen and self.opt.phase == 'test': + self.visualizer.display_scores_histo(self.epoch, self.an_scores, self.gt_labels) + self.visualizer.display_feature(self.last_feature, self.gt_labels) + if self.opt.display_id > 0 and self.opt.phase == 'test': counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset) self.visualizer.plot_performance(self.epoch, counter_ratio, performance) + return performance + ## class Ganomaly(BaseModel): """GANomaly Class """ @property - def name(self): return 'Ganomaly' + def name(self): + return 'Ganomaly' def __init__(self, opt, dataloader): super(Ganomaly, self).__init__(opt, dataloader) @@ -284,11 +339,13 @@ def __init__(self, opt, dataloader): ## # Initialize input tensors. - self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) + self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, + device=self.device) self.label = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device) - self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long, device=self.device) - self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) - self.real_label = torch.ones (size=(self.opt.batchsize,), dtype=torch.float32, device=self.device) + self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long, device=self.device) + self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), + dtype=torch.float32, device=self.device) + self.real_label = torch.ones(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device) self.fake_label = torch.zeros(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device) ## # Setup optimizer @@ -340,7 +397,7 @@ def reinit_d(self): """ Re-initialize the weights of netD """ self.netd.apply(weights_init) - print(' Reloading net d') + if(self.opt.strengthen != 1): print(' Reloading net d') def optimize_params(self): """ Forwardpass, Loss Computation and Backwardpass. @@ -360,3 +417,5 @@ def optimize_params(self): self.backward_d() self.optimizer_d.step() if self.err_d.item() < 1e-5: self.reinit_d() + + diff --git a/lib/networks.py b/lib/networks.py index 2416851..e59ab81 100644 --- a/lib/networks.py +++ b/lib/networks.py @@ -111,6 +111,7 @@ def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0): cngf = cngf // 2 csize = csize * 2 + # Extra layers for t in range(n_extra_layers): main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf), @@ -147,6 +148,7 @@ def __init__(self, opt): self.features = nn.Sequential(*layers[:-1]) self.classifier = nn.Sequential(layers[-1]) + #self.classifier.add_module('Tanh', nn.Tanh()) self.classifier.add_module('Sigmoid', nn.Sigmoid()) def forward(self, x): @@ -156,7 +158,6 @@ def forward(self, x): classifier = classifier.view(-1, 1).squeeze(1) return classifier, features - ## class NetG(nn.Module): """ diff --git a/lib/visualizer.py b/lib/visualizer.py index 6daedfe..6797704 100644 --- a/lib/visualizer.py +++ b/lib/visualizer.py @@ -9,6 +9,9 @@ import time import numpy as np import torchvision.utils as vutils +import plotly.express as px +import plotly.figure_factory as ff + ## class Visualizer(): @@ -87,7 +90,7 @@ def plot_current_errors(self, epoch, counter_ratio, errors): 'xlabel': 'Epoch', 'ylabel': 'Loss' }, - win=4 + win=5 ) ## @@ -112,7 +115,7 @@ def plot_performance(self, epoch, counter_ratio, performance): 'xlabel': 'Epoch', 'ylabel': 'Stats' }, - win=5 + win=6 ) ## @@ -151,6 +154,7 @@ def print_current_performance(self, performance, best): with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) + ## def display_current_images(self, reals, fakes, fixed): """ Display current images. @@ -164,10 +168,60 @@ def display_current_images(self, reals, fakes, fixed): reals = self.normalize(reals.cpu().numpy()) fakes = self.normalize(fakes.cpu().numpy()) fixed = self.normalize(fixed.cpu().numpy()) + fixed_reals = self.normalize(self.fixed_input.cpu().numpy()) + + self.vis.images(reals, win=1, opts={'title': 'Reals'}) + self.vis.images(fakes, win=2, opts={'title': 'Fakes'}) + self.vis.images(fixed, win=3, opts={'title': 'Fixed'}) + self.vis.images(fixed_reals, win=4, opts={'title': 'fixed_reals'}) + + ##point + def display_current_images(self, reals, fakes, fixed, fixed_reals): + """ Display current images. + + Args: + epoch (int): Current epoch + counter_ratio (float): Ratio to plot the range between two epoch. + reals ([FloatTensor]): Real Image + fakes ([FloatTensor]): Fake Image + fixed ([FloatTensor]): Fixed Fake Image + fixed_reals ([FloatTensor]): Fixed real Image + """ + reals = self.normalize(reals.cpu().numpy()) + fakes = self.normalize(fakes.cpu().numpy()) + fixed = self.normalize(fixed.cpu().numpy()) + fixed_reals = self.normalize(fixed_reals.cpu().numpy()) + + self.vis.images(reals, win=1, opts={'title': 'Reals'}) + self.vis.images(fakes, win=2, opts={'title': 'Fakes'}) + self.vis.images(fixed, win=3, opts={'title': 'Fixed'}) + self.vis.images(fixed_reals, win=4, opts={'title': 'fixed_reals'}) + + def display_curent_images_info(self, fix_fake_label, fix_real_label): + + pass + + def display_current_images_test(self, reals, fakes, fixed, fixed_reals): + """ Display current images. + + Args: + epoch (int): Current epoch + counter_ratio (float): Ratio to plot the range between two epoch. + reals ([FloatTensor]): Real Image + fakes ([FloatTensor]): Fake Image + fixed ([FloatTensor]): Fixed Fake Image + fixed_reals ([FloatTensor]): Fixed real Image + """ + reals = self.normalize(reals.cpu().numpy()) + fakes = self.normalize(fakes.cpu().numpy()) + fixed = self.normalize(fixed.cpu().numpy()) + fixed_reals = self.normalize(fixed_reals.cpu().numpy()) self.vis.images(reals, win=1, opts={'title': 'Reals'}) self.vis.images(fakes, win=2, opts={'title': 'Fakes'}) self.vis.images(fixed, win=3, opts={'title': 'Fixed'}) + self.vis.images(fixed_reals, win=4, opts={'title': 'fixed_reals'}) + def save_current_images(self, epoch, reals, fakes, fixed): """ Save images for epoch i. @@ -181,3 +235,80 @@ def save_current_images(self, epoch, reals, fakes, fixed): vutils.save_image(reals, '%s/reals.png' % self.img_dir, normalize=True) vutils.save_image(fakes, '%s/fakes.png' % self.img_dir, normalize=True) vutils.save_image(fixed, '%s/fixed_fakes_%03d.png' %(self.img_dir, epoch+1), normalize=True) + + def save_current_images_s(self, epoch, reals, fakes, fixed, fixed_reals): + """ Save images for epoch i. + + Args: + epoch ([int]) : Current epoch + reals ([FloatTensor]): Real Image + fakes ([FloatTensor]): Fake Image + fixed ([FloatTensor]): Fixed Fake Image + fixed_reals ([FloatTensor]): Fixed Real Image + """ + vutils.save_image(reals, '%s/reals.png' % self.img_dir, normalize=True) + vutils.save_image(fakes, '%s/fakes.png' % self.img_dir, normalize=True) + vutils.save_image(fixed, '%s/fixed_fakes_%03d.png' % (self.img_dir, epoch+1), normalize=True) + vutils.save_image(fixed_reals, '%s/fixed_real.png' % (self.img_dir), normalize=True) + + def save_fixed_real_s(self, fixed_reals): + vutils.save_image(fixed_reals, '%s/fixed_reals.png' % (self.img_dir), normalize=True) + + + def display_scores_histo(self, epoch, scores, labels): + """Display Histogram of the scores for both normal and abnormal test samples + + Args + + """ + scores = scores.cpu().numpy() + labels = labels.cpu().numpy() + abn_score = [] + nor_score = [] + + for i, score in enumerate(scores, 0): + if labels[i] == 1: + abn_score.append(score) + elif labels[i] == 0: + nor_score.append(score) + + hist_data = [abn_score, nor_score] + group_labels = ['Abnormal', 'Normal'] + + fig = ff.create_distplot(hist_data, group_labels, bin_size=0.04) + self.vis.plotlyplot(fig, win=7) + + def display_feature(self, features, labels, alg='t-SNE', win=8, iter=1000): + + labelss = labels[:iter].cpu().numpy() + labels = [i+1 for i in labelss] + features = features[:iter].cpu().numpy().reshape(iter, -1) + + if alg == 't-SNE': + from sklearn.manifold import TSNE + + tsne = TSNE(n_components=3, perplexity=40, learning_rate=140, n_iter=1000) + tsne.fit_transform(features) + + self.vis.scatter(X=tsne.embedding_, Y=labels, win=win, opts={ + 'markersize': 1, + }) + + + + + +""" self.vis.histogram(X=abn_score, win=7, opts={ + 'stucked': False, + 'numbins': 50, + 'color': 'blue' + + }) + self.vis.histogram(X=[nor_score], win=8, opts={ + 'stucked': False, + 'numbins': 50, + 'opacity': 0.1, + 'color': 'red' + }) +""" + diff --git a/options.py b/options.py index f643c52..6c40e55 100644 --- a/options.py +++ b/options.py @@ -51,6 +51,7 @@ def __init__(self): self.parser.add_argument('--abnormal_class', default='car', help='Anomaly class idx for mnist and cifar datasets') self.parser.add_argument('--proportion', type=float, default=0.1, help='Proportion of anomalies in test set.') self.parser.add_argument('--metric', type=str, default='roc', help='Evaluation metric.') + self.parser.add_argument('--strengthen', action='store_true', help='Use strengthen tools.') ## # Train diff --git a/requirements.txt b/requirements.txt index 986a9bc..816ecdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,9 +8,9 @@ idna==2.8 joblib==0.13.2 kiwisolver==1.1.0 matplotlib==3.1.0 -mkl-fft==1.0.12 -mkl-random==1.0.2 -mkl-service==2.0.2 +#mkl-fft==1.0.12 +#mkl-random==1.0.2 +#mkl-service==2.0.2 numpy==1.16.4 olefile==0.46 Pillow==6.2.0 @@ -27,9 +27,11 @@ scipy==1.3.0 six==1.12.0 torch==1.2.0 torchfile==0.1.0 -torchvision==0.4.0a0+6b959ee +#torchvision==0.4.0a0+6b959ee tornado==6.0.3 tqdm==4.33.0 urllib3==1.25.3 visdom==0.1.8.8 websocket-client==0.56.0 +sklearn +plotly