
Commit 214710e
update comments in language training example
Co-authored-by: Huan Zhang <[email protected]>
KaidiXu committed Feb 22, 2022
1 parent 499d023 commit 214710e
Showing 3 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -127,7 +127,7 @@ my_input = BoundedTensor(my_input, ptb)
# Regular forward propagation using BoundedTensor works as usual.
prediction = model(my_input)
# Compute LiRPA bounds using the backward mode bound propagation (CROWN).
-lb, ub = model.compute_bounds(x=(my_input,), method="backward")
+lb, ub = model.compute_bounds(x=(my_input,), method="CROWN")
```
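For context, the snippet above is part of the quick-start example; a self-contained sketch of the full pipeline looks like the following (the toy two-layer network, input shape, and eps value are illustrative assumptions, not part of this repository):

```python
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

# A toy two-layer network; any supported nn.Module works the same way.
model_ori = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
my_input = torch.randn(1, 8)

# Wrap the model once with a dummy input of the right shape.
model = BoundedModule(model_ori, torch.empty_like(my_input))
# Define an L-inf ball of radius 0.1 around the input.
ptb = PerturbationLpNorm(norm=float("inf"), eps=0.1)
my_input = BoundedTensor(my_input, ptb)
# Regular forward propagation using BoundedTensor works as usual.
prediction = model(my_input)
# Compute LiRPA bounds using the backward mode bound propagation (CROWN).
lb, ub = model.compute_bounds(x=(my_input,), method="CROWN")
```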

Checkout
2 changes: 1 addition & 1 deletion auto_LiRPA/perturbations.py
@@ -432,7 +432,7 @@ def init(self, x, aux=None, forward=False):
batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2]

max_pos = 1
-can_be_replaced = np.zeros((batch_size, length), dtype=np.bool)
+can_be_replaced = np.zeros((batch_size, length), dtype=bool)

self._build_substitution(batch)
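
This one-character change is required on recent NumPy: `np.bool` was only ever an alias for the Python builtin `bool`, was deprecated in NumPy 1.20, and was removed in NumPy 1.24. The builtin is a drop-in replacement, as this small check illustrates:

```python
import numpy as np

# dtype=bool produces the same NumPy boolean array that dtype=np.bool used to;
# np.bool_ is the underlying NumPy boolean scalar type and is unaffected.
mask = np.zeros((2, 3), dtype=bool)
assert mask.dtype == np.bool_
```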

34 changes: 23 additions & 11 deletions examples/language/train.py
@@ -1,14 +1,13 @@
"""
A simple script to train a certified-defense LSTM or Transformer using the auto_LiRPA library.
"""

import argparse
import random
import pickle
import os
import pdb
import time
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput
@@ -70,11 +69,18 @@

args = parser.parse_args()

+# Log writer for TensorBoard
writer = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10)
file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log'))
file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))
logger.addHandler(file_handler)

+random.seed(args.seed)
+np.random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.cuda.manual_seed_all(args.seed)

+## Step 1: Prepare the dataset and initialize the original model as usual
data_train_all_nodes, data_train, data_dev, data_test = load_data(args.data)
if args.robust:
data_dev, data_test = clean_data(data_dev), clean_data(data_test)
@@ -86,10 +92,6 @@
logger.info('Dataset sizes: {}/{}/{}/{}'.format(
len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test)))

-random.seed(args.seed)
-np.random.seed(args.seed)
-torch.manual_seed(args.seed)
-torch.cuda.manual_seed_all(args.seed)

dummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device)
dummy_labels = torch.zeros(1, dtype=torch.long, device=args.device)
@@ -100,14 +102,17 @@
elif args.model == 'lstm':
dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device)
model = LSTM(args, data_train)

dev_batches = get_batches(data_dev, args.batch_size)
test_batches = get_batches(data_test, args.batch_size)

+## Step 3: Define the perturbation range; here we use a synonym replacement perturbation constrained by args.budget
ptb = PerturbationSynonym(budget=args.budget)
dummy_embeddings = BoundedTensor(dummy_embeddings, ptb)

+## Step 4: Wrap the model with auto_LiRPA
model_ori = model.model_from_embeddings
-bound_opts = { 'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True }
+bound_opts = {'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True}
if isinstance(model_ori, BoundedModule):
model_bound = model_ori
else:
@@ -179,9 +184,11 @@ def step(model, ptb, batch, eps=1.0, train=False):
acc = (torch.argmax(logits, dim=1) == labels).float().mean()

if robust:
+# Generate specifications
num_class = args.num_classes
c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) - \
torch.eye(num_class).type_as(embeddings).unsqueeze(0)
+# Remove specifications against the true class itself
I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
c = (c[I].view(embeddings.size(0), num_class - 1, num_class))
if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
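
To make the specification matrix construction above concrete, here is a small standalone check (num_class=3 and a single example with label 1 are assumed for illustration):

```python
import torch

num_class = 3
labels = torch.tensor([1])
# One row per wrong class: e_label - e_j encodes the margin logit[label] - logit[j].
c = torch.eye(num_class)[labels].unsqueeze(1) - torch.eye(num_class).unsqueeze(0)
I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
c = c[I].view(labels.size(0), num_class - 1, num_class)
print(c)
# tensor([[[-1.,  1.,  0.],
#          [ 0.,  1., -1.]]])
# A positive lower bound on every row certifies the prediction under perturbation.
```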
@@ -191,11 +198,15 @@ def step(model, ptb, batch, eps=1.0, train=False):
if 1 - eps > 1e-4:
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)
ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)
+# We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
lb = eps * ilb + (1 - eps) * lb
else:
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
else:
raise NotImplementedError

+# Pad a zero at the beginning for each example and use fake label "0" for all examples, because the
+# margins have already been computed by the specifications
lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)
fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
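
A quick sanity check of this padding trick (the margin values below are made up): since each entry of lb lower-bounds a margin logit[true] - logit[other], prepending a zero and using label 0 makes the cross-entropy equal to log(1 + sum_j exp(-lb_j)), a standard robust loss over the worst-case margins:

```python
import torch
from torch.nn import CrossEntropyLoss

lb = torch.tensor([[2.0, 0.5]])  # assumed margin lower bounds, num_class=3
lb_padded = torch.cat((torch.zeros(lb.size(0), 1), lb), dim=1)
fake_labels = torch.zeros(lb.size(0), dtype=torch.int64)
loss = CrossEntropyLoss()(-lb_padded, fake_labels)
ref = torch.log(1 + torch.exp(-lb).sum(dim=1)).mean()  # closed form
assert torch.allclose(loss, ref)
```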
@@ -226,6 +237,7 @@ def train(epoch, batches, type):
assert(optimizer is not None)
train = type == 'train'
if args.robust:
+# Grow epsilon dynamically
eps_scheduler.set_epoch_length(len(batches))
if train:
eps_scheduler.train()
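
For intuition, the scheduler's job is to ramp the perturbation weight from 0 (standard training) to its target (fully robust training) over the course of training; a minimal sketch of that idea, not auto_LiRPA's actual scheduler classes, is:

```python
def linear_eps(step: int, warmup_steps: int, eps_target: float) -> float:
    """Linearly grow eps from 0 to eps_target over warmup_steps, then hold."""
    return eps_target * min(step / warmup_steps, 1.0)

# e.g. with warmup_steps=100: step 0 -> 0.0, step 50 -> 0.5, step 100+ -> 1.0
```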
