From 214710e457213a03b24650c089424b8706abe01a Mon Sep 17 00:00:00 2001
From: KaidiXu
Date: Tue, 22 Feb 2022 13:54:58 -0500
Subject: [PATCH] update comments in language training example

Co-authored-by: Huan Zhang
---
 README.md                   |  2 +-
 auto_LiRPA/perturbations.py |  2 +-
 examples/language/train.py  | 33 +++++++++++++++++++++++----------
 3 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index fcae89e..4a4121d 100644
--- a/README.md
+++ b/README.md
@@ -127,7 +127,7 @@ my_input = BoundedTensor(my_input, ptb)
 # Regular forward propagation using BoundedTensor works as usual.
 prediction = model(my_input)
 # Compute LiRPA bounds using the backward mode bound propagation (CROWN).
-lb, ub = model.compute_bounds(x=(my_input,), method="backward")
+lb, ub = model.compute_bounds(x=(my_input,), method="CROWN")
 ```
 
 Checkout

diff --git a/auto_LiRPA/perturbations.py b/auto_LiRPA/perturbations.py
index f0ddc35..c3fc386 100644
--- a/auto_LiRPA/perturbations.py
+++ b/auto_LiRPA/perturbations.py
@@ -432,7 +432,7 @@ def init(self, x, aux=None, forward=False):
         batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2]
 
         max_pos = 1
-        can_be_replaced = np.zeros((batch_size, length), dtype=np.bool)
+        can_be_replaced = np.zeros((batch_size, length), dtype=bool)
 
         self._build_substitution(batch)
 
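A note on the README hunk above: in auto_LiRPA, `"CROWN"` is an alias for the `"backward"` bound propagation method, so the rename only makes the example match the paper's terminology. For reference, a minimal self-contained sketch of that quick-start snippet (the toy network, input shape, and `eps` value are illustrative assumptions, not code from this patch):

```python
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

# A toy two-layer network standing in for the user's model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
my_input = torch.randn(1, 4)
# Wrap the model, then attach an L-infinity perturbation to the input.
model = BoundedModule(model, torch.empty_like(my_input))
ptb = PerturbationLpNorm(norm=float("inf"), eps=0.01)
my_input = BoundedTensor(my_input, ptb)
# Regular forward propagation using BoundedTensor works as usual.
prediction = model(my_input)
# method="CROWN" is equivalent to the pre-patch method="backward".
lb, ub = model.compute_bounds(x=(my_input,), method="CROWN")
```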
diff --git a/examples/language/train.py b/examples/language/train.py
index 8c6f5b2..1d6d2a3 100644
--- a/examples/language/train.py
+++ b/examples/language/train.py
@@ -1,14 +1,14 @@
+"""
+A simple script for training a certifiably robust (certified defense) LSTM or Transformer using the auto_LiRPA library.
+"""
+
 import argparse
 import random
 import pickle
 import os
-import pdb
-import time
 import logging
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.utils.tensorboard import SummaryWriter
 from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput
@@ -70,11 +70,18 @@
 
 args = parser.parse_args()
 
+# Create the log writer for TensorBoard.
 writer = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10)
 file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log'))
 file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))
 logger.addHandler(file_handler)
 
+random.seed(args.seed)
+np.random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.cuda.manual_seed_all(args.seed)
+
+## Step 1: Prepare the dataset and initialize the original model as usual
 data_train_all_nodes, data_train, data_dev, data_test = load_data(args.data)
 if args.robust:
     data_dev, data_test = clean_data(data_dev), clean_data(data_test)
@@ -86,10 +93,6 @@
 logger.info('Dataset sizes: {}/{}/{}/{}'.format(
     len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test)))
 
-random.seed(args.seed)
-np.random.seed(args.seed)
-torch.manual_seed(args.seed)
-torch.cuda.manual_seed_all(args.seed)
 
 dummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device)
 dummy_labels = torch.zeros(1, dtype=torch.long, device=args.device)
@@ -100,14 +103,17 @@
 elif args.model == 'lstm':
     dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device)
     model = LSTM(args, data_train)
-    
+
 dev_batches = get_batches(data_dev, args.batch_size)
 test_batches = get_batches(data_test, args.batch_size)
 
+## Step 2: Define the perturbation; here we use a synonym replacement perturbation constrained by args.budget
 ptb = PerturbationSynonym(budget=args.budget)
 dummy_embeddings = BoundedTensor(dummy_embeddings, ptb)
+
+## Step 3: Wrap the model with auto_LiRPA
 model_ori = model.model_from_embeddings
-bound_opts = { 'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True }
+bound_opts = {'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True}
 if isinstance(model_ori, BoundedModule):
     model_bound = model_ori
 else:
@@ -179,9 +185,11 @@ def step(model, ptb, batch, eps=1.0, train=False):
     acc = (torch.argmax(logits, dim=1) == labels).float().mean()
 
     if robust:
+        # Generate the specification matrix of margins between classes.
         num_class = args.num_classes
         c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) - \
             torch.eye(num_class).type_as(embeddings).unsqueeze(0)
+        # Remove the trivial specification of each label against itself.
         I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
         c = (c[I].view(embeddings.size(0), num_class - 1, num_class))
         if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
@@ -191,11 +199,15 @@ def step(model, ptb, batch, eps=1.0, train=False):
             if 1 - eps > 1e-4:
                 lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)
                 ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)
+                # Mix the IBP and CROWN-IBP bounds, which leads to better performance (Zhang et al., ICLR 2020).
                 lb = eps * ilb + (1 - eps) * lb
             else:
                 lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
         else:
             raise NotImplementedError
+
+        # Pad a zero at the beginning of each example and use the fake label "0" for all examples,
+        # because the margins have already been computed by the specifications.
         lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)
         fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
         loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
@@ -226,6 +238,7 @@ def train(epoch, batches, type):
     assert(optimizer is not None)
     train = type == 'train'
     if args.robust:
+        # Grow epsilon gradually during training according to the scheduler.
         eps_scheduler.set_epoch_length(len(batches))
         if train:
             eps_scheduler.train()
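To make the specification and padding comments in `step()` concrete, here is a standalone sketch with toy numbers (a batch of 2 examples and 3 classes). The randomized `lb` stands in for the margin lower bounds that `compute_bounds(C=c, ...)` would return; the remaining lines mirror the hunks above:

```python
import torch
from torch.nn import CrossEntropyLoss

num_class = 3
labels = torch.tensor([0, 2])  # ground-truth classes for a toy batch of 2

# Each row of c encodes a margin: (true class logit) - (other class logit).
c = torch.eye(num_class)[labels].unsqueeze(1) - torch.eye(num_class).unsqueeze(0)
# Drop the trivial "true class vs. itself" row, leaving num_class - 1 margins.
I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
c = c[I].view(labels.size(0), num_class - 1, num_class)  # shape (2, 2, 3)

# Stand-in for the lower bounds on the margins returned by compute_bounds.
lb = torch.randn(labels.size(0), num_class - 1)

# Pad a zero margin for the true class and use fake label 0: the cross
# entropy of -lb_padded is then the standard robust CE surrogate loss.
lb_padded = torch.cat((torch.zeros(lb.size(0), 1), lb), dim=1)
fake_labels = torch.zeros(lb.size(0), dtype=torch.int64)
loss_robust = CrossEntropyLoss()(-lb_padded, fake_labels)
print(loss_robust.item())
```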