diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
deleted file mode 100644
index 9af86dc..0000000
--- a/.github/workflows/CI.yml
+++ /dev/null
@@ -1,32 +0,0 @@
-name: auto_LiRPA CI
-on: [push]
-jobs:
- Tests:
- runs-on: ubuntu-latest
- steps:
- - name: Create swap
- run: |
- sudo fallocate -l 16G /swapfile
- sudo chmod 600 /swapfile
- sudo mkswap /swapfile
- sudo swapon /swapfile
- free -h
- - name: Setup Python
- uses: actions/setup-python@v2.2.2
- with:
- python-version: 3.8
- architecture: x64
- - name: Check out repository code
- uses: actions/checkout@v2
- - name: Install auto_LiRPA
- run: python setup.py install
- - name: Install dependencies for examples
- run: |
- cd examples
- pip install -r requirements.txt
- cd ..
- - name: Run tests
- run: |
- cd tests
- python utils/download_models.py
- pytest
diff --git a/.gitignore b/.gitignore
index 4197b11..3ace7db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,5 +4,16 @@ __pycache__
*.egg-info
dist
*.swp
+*.swo
*.log
-.trace_graph
\ No newline at end of file
+.trace_graph
+Verified_ret*.npy
+Verified-acc*.npy
+vnn-comp_*.npz
+*.tar.gz
+verifier_log_*
+*.pth
+*.pt
+.idea
+*.so
+release
diff --git a/.travis.yml b/.travis.yml
index c9418c8..8c80a27 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,8 +3,10 @@ python:
- "3.8"
install:
- pip install --editable .
- - cd examples
+ - cd examples
- pip install -r requirements.txt
+ - pip install git+https://github.com/KaidiXu/onnx2pytorch.git
+ - pip install onnxruntime
- cd ..
- sudo fallocate -l 16G /swapfile
- sudo chmod 600 /swapfile
@@ -14,4 +16,5 @@ install:
script:
- cd tests
- python utils/download_models.py
- - pytest
+ - pytest
+ - cd ..
\ No newline at end of file
diff --git a/README.md b/README.md
index 4a4121d..42c2200 100644
--- a/README.md
+++ b/README.md
@@ -13,16 +13,17 @@
## What's New?
-- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library.
-- Support for [custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html). (01/02/2022)
+- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)), which uses `auto_LiRPA` as its core library, **won** [VNN-COMP 2022](https://sites.google.com/view/vnn2022). Our library supports the large CIFAR100, TinyImageNet and ImageNet models in VNN-COMP 2022. (09/2022)
+- Implementation of **general cutting planes** ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)), support of more activation functions and improved performance and scalability. (09/2022)
+- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. (09/2021)
- [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021)
- Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021)
-- A memory efficient GPU implementation of backward (CROWN) bounds for
+- A memory-efficient GPU implementation of backward (CROWN) bounds for
convolutional layers. (10/31/2020)
- Certified defense models for downscaled ImageNet, TinyImageNet, CIFAR-10, LSTM/Transformer. (08/20/2020)
- Adding support to **complex vision models** including DenseNet, ResNeXt and WideResNet. (06/30/2020)
-- **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds
-(e.g. CROWN-IBP) to the same asympototic complexity of IBP, making LiRPA based certified
+- **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds
+(e.g. CROWN-IBP) to the same asymptotic complexity of IBP, making LiRPA-based certified
defense scalable to large datasets (e.g., TinyImageNet, downscaled ImageNet). (06/30/2020)
- **Multi-GPU** support to scale LiRPA based training to large models and datasets. (06/30/2020)
- Initial release. (02/28/2020)
@@ -46,6 +47,7 @@ Our library supports the following algorithms:
* Backward mode LiRPA bound propagation ([CROWN](https://arxiv.org/pdf/1811.00866.pdf)/[DeepPoly](https://files.sri.inf.ethz.ch/website/papers/DeepPoly.pdf))
* Backward mode LiRPA bound propagation with optimized bounds ([α-CROWN](https://arxiv.org/pdf/2011.13824.pdf))
* Backward mode LiRPA bound propagation with split constraints ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf))
+* Generalized backward mode LiRPA bound propagation with general cutting plane constraints ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf))
* Forward mode LiRPA bound propagation ([Xu et al., 2020](https://arxiv.org/pdf/2002.12920))
* Forward mode LiRPA bound propagation with optimized bounds (similar to [α-CROWN](https://arxiv.org/pdf/2011.13824.pdf))
* Interval bound propagation ([IBP](https://arxiv.org/pdf/1810.12715.pdf))
@@ -89,9 +91,11 @@ user-defined ranges. We get guaranteed output ranges (bounds):
## Installation
-Python 3.7+ is required. Pytorch 1.8 (LTS) is recommended, although a newer
-version might also work. It is highly recommended to have a pre-installed PyTorch
-that matches your system and our version requirement. See [PyTorch Get Started](https://pytorch.org/get-started).
+Python 3.7+ and PyTorch 1.8+ are required.
+PyTorch 1.11 is recommended, although other recent versions might also work.
+It is highly recommended to have a pre-installed PyTorch
+that matches your system and our version requirement.
+See [PyTorch Get Started](https://pytorch.org/get-started).
Then you can install `auto_LiRPA` via:
```bash
@@ -102,6 +106,11 @@ python setup.py install
If you intend to modify this library, use `python setup.py develop` instead.
+Optionally, you may build and install native CUDA modules (CUDA toolkit required):
+```bash
+python auto_LiRPA/cuda_utils.py install
+```
+
## Quick Start
First define your computation as a `nn.Module` and wrap it using
@@ -127,7 +136,7 @@ my_input = BoundedTensor(my_input, ptb)
# Regular forward propagation using BoundedTensor works as usual.
prediction = model(my_input)
# Compute LiRPA bounds using the backward mode bound propagation (CROWN).
-lb, ub = model.compute_bounds(x=(my_input,), method="CROWN")
+lb, ub = model.compute_bounds(x=(my_input,), method="backward")
```
Checkout
@@ -142,7 +151,7 @@ obtaining gradients through autodiff. Bounds are efficiently computed on GPUs.
## More Working Examples
-We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`:
+We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`:
* [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/src/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks)
* [Basic **Certified Adversarial Defense** Training](doc/src/examples.md#basic-certified-adversarial-defense-training)
@@ -151,6 +160,10 @@ We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`
* [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm)
* [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense)
+`auto_LiRPA` has also been used in the following works:
+* [**α,β-CROWN for complete neural network verification**](https://github.com/huanzhang12/alpha-beta-CROWN)
+* [**Fast certified robust training**](https://github.com/shizhouxing/Fast-Certified-Robust-Training)
+
## Full Documentations
For more documentations, please refer to:
@@ -166,20 +179,27 @@ Please kindly cite our papers if you use the `auto_LiRPA` library. Full [BibTeX
The general LiRPA based bound propagation algorithm was originally proposed in our paper:
-* [Automatic Perturbation Analysis for Scalable Certified Robustness and Beyond](https://arxiv.org/pdf/2002.12920).
-NeurIPS 2020
+* [Automatic Perturbation Analysis for Scalable Certified Robustness and Beyond](https://arxiv.org/pdf/2002.12920).
+NeurIPS 2020
Kaidi Xu\*, Zhouxing Shi\*, Huan Zhang\*, Yihan Wang, Kai-Wei Chang, Minlie Huang, Bhavya Kailkhura, Xue Lin, Cho-Jui Hsieh (\* Equal contribution)
-The `auto_LiRPA` library is further extended to allow optimized bound (α-CROWN) and split constraints (β-CROWN):
+The `auto_LiRPA` library is further extended to allow optimized bounds (α-CROWN), split constraints (β-CROWN) and general cutting plane constraints (GCP-CROWN):
-* [Fast and Complete: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers](https://arxiv.org/pdf/2011.13824.pdf).
-ICLR 2021
-Kaidi Xu\*, Huan Zhang\*, Shiqi Wang, Yihan Wang, Suman Jana, Xue Lin and Cho-Jui Hsieh (\* Equal contribution)
+* [Fast and Complete: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers](https://arxiv.org/pdf/2011.13824.pdf).
+ICLR 2021.
+Kaidi Xu\*, Huan Zhang\*, Shiqi Wang, Yihan Wang, Suman Jana, Xue Lin and Cho-Jui Hsieh (\* Equal contribution).
-* [Beta-CROWN: Efficient Bound Propagation with Per-neuron Split Constraints for Complete and Incomplete Neural Network Verification](https://arxiv.org/pdf/2103.06624.pdf).
-NeurIPS 2021
-Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution)
+* [Beta-CROWN: Efficient Bound Propagation with Per-neuron Split Constraints for Complete and Incomplete Neural Network Verification](https://arxiv.org/pdf/2103.06624.pdf).
+NeurIPS 2021.
+Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution).
+* [GCP-CROWN: General Cutting Planes for Bound-Propagation-Based Neural Network Verification](https://arxiv.org/abs/2208.05740).
+Huan Zhang\*, Shiqi Wang\*, Kaidi Xu\*, Linyi Li, Bo Li, Suman Jana, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution).
+
+Certified robust training using `auto_LiRPA` is improved to allow much shorter warmup and faster training:
+* [Fast Certified Robust Training with Short Warmup](https://arxiv.org/pdf/2103.17268.pdf).
+NeurIPS 2021.
+Zhouxing Shi\*, Yihan Wang\*, Huan Zhang, Jinfeng Yi and Cho-Jui Hsieh (\* Equal contribution).
## Developers and Copyright
@@ -187,12 +207,20 @@ Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Z
|:--:|:--:| :--:| :--:| :--:|
| | | | | |
-* Kaidi Xu (xu.kaid@northeastern.edu): main developer
-* Zhouxing Shi (zshi@cs.ucla.edu): main developer
-* Huan Zhang (huan@huan-zhang.com): team lead
-* Yihan Wang (yihanwang@ucla.edu)
-* Shiqi Wang (sw3215@columbia.edu): contact for beta-CROWN
+Team lead:
+* Huan Zhang (huan@huan-zhang.com), CMU
+
+Main developers:
+* Zhouxing Shi (zshi@cs.ucla.edu), UCLA
+* Kaidi Xu (kx46@drexel.edu), Drexel University
+
+Contributors:
+* Yihan Wang (yihanwang@ucla.edu), UCLA
+* Shiqi Wang (sw3215@columbia.edu), Columbia University
+* Linyi Li (linyi2@illinois.edu), UIUC
+* Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU
+* Zhuolin Yang (zhuolin5@illinois.edu), UIUC
-We thank [commits](https://github.com/KaidiXu/auto_LiRPA/commits) and [pull requests](https://github.com/KaidiXu/auto_LiRPA/pulls) from community contributors.
+We thank the [commits](https://github.com/KaidiXu/auto_LiRPA/commits) and [pull requests](https://github.com/KaidiXu/auto_LiRPA/pulls) from community contributors.
Our library is released under the BSD 3-Clause license.
diff --git a/auto_LiRPA/__init__.py b/auto_LiRPA/__init__.py
index f0b5922..3fbdb43 100644
--- a/auto_LiRPA/__init__.py
+++ b/auto_LiRPA/__init__.py
@@ -1,7 +1,8 @@
-from .bound_general import BoundedModule, BoundDataParallel
+from .bound_general import BoundedModule
+from .bound_multi_gpu import BoundDataParallel
from .bounded_tensor import BoundedTensor, BoundedParameter
from .perturbations import PerturbationLpNorm, PerturbationSynonym
from .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput
from .bound_op_map import register_custom_op, unregister_custom_op
-__version__ = '0.2'
\ No newline at end of file
+__version__ = '0.3'
diff --git a/auto_LiRPA/adam_element_lr.py b/auto_LiRPA/adam_element_lr.py
deleted file mode 100644
index 8324110..0000000
--- a/auto_LiRPA/adam_element_lr.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import torch
-import math
-from torch.optim.optimizer import Optimizer
-from torch import Tensor
-from typing import List, Optional
-
-
-def adam(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- max_exp_avg_sqs: List[Tensor],
- state_steps: List[int],
- *,
- amsgrad: bool,
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float,
- lr_scale: Optional[Tensor],
- batch_dim: Optional[int]):
- r"""Functional API that performs Adam algorithm computation.
- See :class:`~torch.optim.Adam` for details.
- """
-
- for i, param in enumerate(params):
-
- grad = grads[i]
- exp_avg = exp_avgs[i]
- exp_avg_sq = exp_avg_sqs[i]
- step = state_steps[i]
-
- bias_correction1 = 1 - beta1 ** step
- bias_correction2 = 1 - beta2 ** step
-
- if weight_decay != 0:
- grad = grad.add(param, alpha=weight_decay)
-
- # Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- if amsgrad:
- # Maintains the maximum of all 2nd moment running avg. till now
- torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
- # Use the max. for normalizing running avg. of gradient
- denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
- else:
- denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-
- step_size = lr / bias_correction1
- if lr_scale is not None:
- # Per batch element learning rate scaler.
- # We must know which dimension is corresponding to batch and broadcast accordingly.
- total_dim = exp_avg.ndim
- new_shape = (1, ) * batch_dim + (lr_scale.size(0), ) + (1, ) * (total_dim - 1 - batch_dim)
- scaler = lr_scale.view(*new_shape)
- param.addcdiv_(scaler * exp_avg, denom, value=-step_size)
- else:
- param.addcdiv_(exp_avg, denom, value=-step_size)
- if lr_scale is not None:
- pass
- # print('lr scaler', lr_scale)
-
-
-class AdamElementLR(Optimizer):
- r"""Implements Adam algorithm, with the capability of setting different lr
- per batch element.
- It has been proposed in `Adam: A Method for Stochastic Optimization`_.
- The implementation of the L2 penalty follows changes proposed in
- `Decoupled Weight Decay Regularization`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- amsgrad (boolean, optional): whether to use the AMSGrad variant of this
- algorithm from the paper `On the Convergence of Adam and Beyond`_
- (default: False)
- .. _Adam\: A Method for Stochastic Optimization:
- https://arxiv.org/abs/1412.6980
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False):
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
- if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
- if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
- if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay, amsgrad=amsgrad)
- super(AdamElementLR, self).__init__(params, defaults)
-
- def __setstate__(self, state):
- super(Adam, self).__setstate__(state)
- for group in self.param_groups:
- group.setdefault('amsgrad', False)
-
- @torch.no_grad()
- def step(self, lr_scale=None, closure=None):
- """Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for i, group in enumerate(self.param_groups):
- params_with_grad = []
- grads = []
- exp_avgs = []
- exp_avg_sqs = []
- max_exp_avg_sqs = []
- state_steps = []
- beta1, beta2 = group['betas']
-
- for p in group['params']:
- if p.grad is not None:
- params_with_grad.append(p)
- if p.grad.is_sparse:
- raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
- grads.append(p.grad)
-
- state = self.state[p]
- # Lazy state initialization
- if len(state) == 0:
- state['step'] = 0
- # Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
- # Exponential moving average of squared gradient values
- state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
- if group['amsgrad']:
- # Maintains max of all exp. moving avg. of sq. grad. values
- state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
- exp_avgs.append(state['exp_avg'])
- exp_avg_sqs.append(state['exp_avg_sq'])
-
- if group['amsgrad']:
- max_exp_avg_sqs.append(state['max_exp_avg_sq'])
-
- # update the steps for each param group update
- state['step'] += 1
- # record the step after step update
- state_steps.append(state['step'])
-
- adam(params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- max_exp_avg_sqs,
- state_steps,
- amsgrad=group['amsgrad'],
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'],
- lr_scale=lr_scale[i] if lr_scale is not None else None,
- batch_dim=group['batch_dim'],
- )
- return loss
diff --git a/auto_LiRPA/backward_bound.py b/auto_LiRPA/backward_bound.py
new file mode 100644
index 0000000..411ae9e
--- /dev/null
+++ b/auto_LiRPA/backward_bound.py
@@ -0,0 +1,766 @@
+import torch
+from torch import Tensor
+from collections import deque, defaultdict
+from tqdm import tqdm
+from .patches import Patches
+from .utils import *
+from .bound_ops import *
+import warnings
+
+
+def batched_backward(
+ self, node, C, unstable_idx, batch_size, bound_lower=True,
+ bound_upper=True):
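+ """Compute bounds of unstable neurons in batches: a small specification
+ matrix C is built for each batch of unstable neurons to limit the peak
+ memory usage of backward propagation."""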
+ crown_batch_size = self.bound_opts['crown_batch_size']
+ unstable_size = get_unstable_size(unstable_idx)
+ print(f'Batched CROWN: unstable size {unstable_size}')
+ num_batches = (unstable_size + crown_batch_size - 1) // crown_batch_size
+ output_shape = node.output_shape[1:]
+ dim = int(prod(output_shape))
+ ret = []
+ for i in tqdm(range(num_batches)):
+ if isinstance(unstable_idx, tuple):
+ unstable_idx_batch = tuple(
+ u[i*crown_batch_size:(i+1)*crown_batch_size]
+ for u in unstable_idx
+ )
+ unstable_size_batch = len(unstable_idx_batch[0])
+ else:
+ unstable_idx_batch = unstable_idx[i*crown_batch_size:(i+1)*crown_batch_size]
+ unstable_size_batch = len(unstable_idx_batch)
+ if node.patches_start and node.mode == "patches":
+ assert C in ['Patches', None]
+ C_batch = Patches(shape=[
+ unstable_size_batch, batch_size, *node.output_shape[1:-2], 1, 1],
+ identity=1, unstable_idx=unstable_idx_batch,
+ output_shape=[batch_size, *node.output_shape[1:]])
+ elif isinstance(node, BoundLinear) or isinstance(node, BoundMatMul):
+ assert C in ['OneHot', None]
+ C_batch = OneHotC(
+ [batch_size, unstable_size_batch, *node.output_shape[1:]],
+ self.device, unstable_idx_batch, None)
+ else:
+ assert C in ['eye', None]
+ C_batch = torch.zeros([1, unstable_size_batch, dim], device=self.device)
+ C_batch[0, torch.arange(unstable_size_batch), unstable_idx_batch] = 1.0
+ C_batch = C_batch.expand(batch_size, -1, -1).view(
+ batch_size, unstable_size_batch, *output_shape)
+ ret.append(self.backward_general(
+ C=C_batch, node=node,
+ bound_lower=bound_lower, bound_upper=bound_upper,
+ average_A=False, need_A_only=False, unstable_idx=unstable_idx_batch,
+ verbose=False))
+ if bound_lower:
+ lb = torch.cat([item[0].view(batch_size, -1) for item in ret], dim=1)
+ else:
+ lb = None
+ if bound_upper:
+ ub = torch.cat([item[1].view(batch_size, -1) for item in ret], dim=1)
+ else:
+ ub = None
+ return lb, ub
+
+
+def backward_general(
+ self, C, node=None, bound_lower=True, bound_upper=True, average_A=False,
+ need_A_only=False, unstable_idx=None, unstable_size=0, update_mask=None, verbose=True):
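+ """Backward mode bound propagation (CROWN): propagate the linear
+ coefficients in C from `node` back to the root nodes, then concretize
+ them into lower/upper bounds of the node's output."""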
+ if verbose:
+ logger.debug(f'Bound backward from {node.__class__.__name__}({node.name})')
+ if isinstance(C, str):
+ logger.debug(f' C: {C}')
+ elif C is not None:
+ logger.debug(f' C: shape {C.shape}, type {type(C)}')
+ _print_time = False
+
+ if isinstance(C, str):
+ # If C is a str, use batched CROWN. If batched CROWN is not intended to
+ # be enabled, C must be an explicitly provided non-str object for this function.
+ if need_A_only or self.return_A or average_A:
+ raise ValueError(
+ 'Batched CROWN is not compatible with '
+ f'need_A_only={need_A_only}, return_A={self.return_A}, '
+ f'average_A={average_A}')
+ node.lower, node.upper = self.batched_backward(
+ node, C, unstable_idx,
+ batch_size=self.root[0].value.shape[0],
+ bound_lower=bound_lower, bound_upper=bound_upper,
+ )
+ return node.lower, node.upper
+
+ for l in self._modules.values():
+ l.lA = l.uA = None
+ l.bounded = True
+
+ degree_out = get_degrees(node, self.backward_from)
+ all_nodes_before = list(degree_out.keys())
+
+ C, batch_size, output_dim, output_shape = preprocess_C(C, node)
+
+ node.lA = C if bound_lower else None
+ node.uA = C if bound_upper else None
+ lb = ub = torch.tensor(0., device=self.device)
+
+ # Save intermediate layer A matrices when required.
+ A_record = {}
+
+ queue = deque([node])
+ while len(queue) > 0:
+ l = queue.popleft() # backward from l
+ l.bounded = True
+
+ if l.name in self.root_name: continue
+
+ # If all of its successors are done, this node can be processed in the
+ # next iteration.
+ for l_pre in l.inputs:
+ degree_out[l_pre.name] -= 1
+ if degree_out[l_pre.name] == 0:
+ queue.append(l_pre)
+
+ # Initially, l.lA or l.uA will be set to C for this node.
+ if l.lA is not None or l.uA is not None:
+ if verbose:
+ logger.debug(f' Bound backward to {l}'
+ f' (lA shape {l.lA.shape if l.lA is not None else None},'
+ f' uA shape {l.uA.shape if l.uA is not None else None},'
+ f' out shape {l.output_shape})')
+
+ if _print_time:
+ start_time = time.time()
+
+ if not l.perturbed:
+ if not hasattr(l, 'forward_value'):
+ self.get_forward_value(l)
+ lb, ub = add_constant_node(lb, ub, l)
+ continue
+
+ if l.zero_uA_mtx and l.zero_lA_mtx:
+ # A matrices are all zero, no need to propagate.
+ continue
+
+ if isinstance(l, BoundRelu):
+ # TODO: unify this interface.
+ A, lower_b, upper_b = l.bound_backward(
+ l.lA, l.uA, *l.inputs, start_node=node, unstable_idx=unstable_idx,
+ beta_for_intermediate_layers=self.intermediate_constr is not None)
+ elif isinstance(l, BoundOptimizableActivation):
+ # For other optimizable activation functions (TODO: unify with ReLU).
+ if node.name != self.final_node_name:
+ start_shape = node.output_shape[1:]
+ else:
+ start_shape = C.shape[0]
+ A, lower_b, upper_b = l.bound_backward(
+ l.lA, l.uA, *l.inputs, start_shape=start_shape, start_node=node)
+ else:
+ A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs)
+ # After propagation through this node, we delete its lA, uA variables.
+ if not self.return_A and node.name != self.final_name:
+ del l.lA, l.uA
+ if _print_time:
+ time_elapsed = time.time() - start_time
+ if time_elapsed > 1e-3:
+ print(l, time_elapsed)
+ if lb.ndim > 0 and type(lower_b) == Tensor and self.conv_mode == 'patches':
+ lb, ub, lower_b, upper_b = check_patch_biases(lb, ub, lower_b, upper_b)
+ lb = lb + lower_b
+ ub = ub + upper_b
+ if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict:
+ # FIXME remove [0][0] and [0][1]?
+ if len(self.needed_A_dict[node.name]) == 0 or l.name in self.needed_A_dict[node.name]:
+ A_record.update({l.name: {
+ "lA": A[0][0].transpose(0, 1).detach() if A[0][0] is not None else None,
+ "uA": A[0][1].transpose(0, 1).detach() if A[0][1] is not None else None,
+ # When not used, lb or ub is tensor(0).
+ "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None,
+ "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None,
+ "unstable_idx": unstable_idx
+ }})
+ # FIXME: solve conflict with the following case
+ self.A_dict.update({node.name: A_record})
+ if need_A_only and set(self.needed_A_dict[node.name]) == set(A_record.keys()):
+ # We have collected all A matrices we need. We can return now!
+ self.A_dict.update({node.name: A_record})
+ # Do not concretize to save time. We just need the A matrices.
+ # return A matrix as a dict: {node.name: [A_lower, A_upper]}
+ return None, None, self.A_dict
+
+ for i, l_pre in enumerate(l.inputs):
+ add_bound(l, l_pre, lA=A[i][0], uA=A[i][1])
+
+ if lb.ndim >= 2:
+ lb = lb.transpose(0, 1)
+ if ub.ndim >= 2:
+ ub = ub.transpose(0, 1)
+
+ if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict:
+ save_A_record(
+ node, A_record, self.A_dict, self.root, self.needed_A_dict[node.name],
+ lb=lb, ub=ub, unstable_idx=unstable_idx)
+
+ # TODO merge into `concretize`
+ if self.cut_used and getattr(self, 'cut_module', None) is not None and self.cut_module.x_coeffs is not None:
+ # propagate input neuron in cut constraints
+ self.root[0].lA, self.root[0].uA = self.cut_module.input_cut(
+ node, self.root[0].lA, self.root[0].uA, self.root[0].lower.size()[1:], unstable_idx,
+ batch_mask=update_mask)
+
+ lb, ub = concretize(
+ lb, ub, node, self.root, batch_size, output_dim,
+ bound_lower, bound_upper, average_A=average_A)
+
+ # TODO merge into `concretize`
+ if self.cut_used and getattr(self, "cut_module", None) is not None and self.cut_module.cut_bias is not None:
+ # propagate cut bias in cut constraints
+ lb, ub = self.cut_module.bias_cut(node, lb, ub, unstable_idx, batch_mask=update_mask)
+ if lb is not None and ub is not None and ((lb-ub)>0).sum().item() > 0:
+ # make sure there is no bug for cut constraints propagation
+ print(f"Warning: lb is larger than ub with diff: {(lb-ub)[(lb-ub)>0].max().item()}")
+
+ lb = lb.view(batch_size, *output_shape) if bound_lower else None
+ ub = ub.view(batch_size, *output_shape) if bound_upper else None
+
+ if verbose:
+ logger.debug('')
+
+ if self.return_A:
+ return lb, ub, self.A_dict
+ else:
+ return lb, ub
+
+
+def get_unstable_size(unstable_idx):
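+ """Count unstable neurons given their indices (a tuple of index tensors
+ in patches mode, or a single index tensor otherwise)."""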
+ if isinstance(unstable_idx, tuple):
+ return unstable_idx[0].numel()
+ else:
+ return unstable_idx.numel()
+
+
+def check_optimized_variable_sparsity(self, node):
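+ """Check whether sparse alpha variables were created for `node` in any
+ ReLU layer; returns True/False, or None if unknown."""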
+ alpha_sparsity = None # unknown.
+ for relu in self.relus:
+ if hasattr(relu, 'alpha_lookup_idx') and node.name in relu.alpha_lookup_idx:
+ if relu.alpha_lookup_idx[node.name] is not None:
+ # This node was created with sparse alpha
+ alpha_sparsity = True
+ else:
+ alpha_sparsity = False
+ break
+ # print(f'node {node.name} alpha sparsity {alpha_sparsity}')
+ return alpha_sparsity
+
+
+def get_sparse_C(
+ self, node, sparse_intermediate_bounds=True, ref_intermediate_lb=None,
+ ref_intermediate_ub=None):
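+ """Build the specification matrix C for computing intermediate bounds of
+ `node`, using a sparse or batched C when only a small subset of neurons
+ is unstable."""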
+ sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False)
+ minimum_sparsity = self.bound_opts.get('minimum_sparsity', 0.9)
+ crown_batch_size = self.bound_opts.get('crown_batch_size', 1e9)
+ dim = int(prod(node.output_shape[1:]))
+ batch_size = self.batch_size
+
+ reduced_dim = False # Only partial neurons (unstable neurons) are bounded.
+ unstable_idx = None
+ unstable_size = np.inf
+ newC = None
+
+ alpha_is_sparse = self.check_optimized_variable_sparsity(node)
+
+ # NOTE: batched CROWN is so far only supported for some of the cases below
+
+ # FIXME: C matrix shape incorrect for BoundParams.
+ if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int(
+ os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0:
+ if sparse_intermediate_bounds:
+ # If we are doing bound refinement and reference bounds are given, we only refine unstable neurons.
+ # Also, if we are checking against LP solver we will refine all neurons and do not use this optimization.
+ # For each batch element, we find the unstable neurons.
+ unstable_idx, unstable_size = self.get_unstable_locations(
+ ref_intermediate_lb, ref_intermediate_ub)
+ if unstable_size == 0:
+ # Do nothing, no bounds will be computed.
+ reduced_dim = True
+ unstable_idx = []
+ elif unstable_size > crown_batch_size:
+ # Create C in batched CROWN
+ newC = 'OneHot'
+ reduced_dim = True
+ elif (unstable_size <= minimum_sparsity * dim and unstable_size > 0 and alpha_is_sparse is None) or alpha_is_sparse:
+ # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity.
+ # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches.
+ newC = OneHotC(
+ [batch_size, unstable_size, *node.output_shape[1:]],
+ self.device, unstable_idx, None)
+ reduced_dim = True
+ else:
+ unstable_idx = None
+ del ref_intermediate_lb, ref_intermediate_ub
+ if not reduced_dim:
+ newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device)
+ elif node.patches_start and node.mode == "patches":
+ if sparse_intermediate_bounds:
+ unstable_idx, unstable_size = self.get_unstable_locations(
+ ref_intermediate_lb, ref_intermediate_ub, conv=True)
+ if unstable_size == 0:
+ # Do nothing, no bounds will be computed.
+ reduced_dim = True
+ unstable_idx = []
+ elif unstable_size > crown_batch_size:
+ # Create C in batched CROWN
+ newC = 'Patches'
+ reduced_dim = True
+ # We sum over the channel direction, so need to multiply that.
+ elif (sparse_conv_intermediate_bounds and unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse:
+ # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity.
+ # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches.
+ # The shape of patches is [unstable_size, batch, C, H, W].
+ newC = Patches(
+ shape=[unstable_size, batch_size, *node.output_shape[1:-2], 1, 1],
+ identity=1, unstable_idx=unstable_idx,
+ output_shape=[batch_size, *node.output_shape[1:]])
+ reduced_dim = True
+ else:
+ unstable_idx = None
+ del ref_intermediate_lb, ref_intermediate_ub
+ # Here we create an Identity Patches object
+ if not reduced_dim:
+ newC = Patches(
+ None, 1, 0, [node.output_shape[1], batch_size, *node.output_shape[2:],
+ *node.output_shape[1:-2], 1, 1], 1,
+ output_shape=[batch_size, *node.output_shape[1:]])
+ elif isinstance(node, (BoundAdd, BoundSub)) and node.mode == "patches":
+ # FIXME: BoundAdd does not always have patches. Need to use a better way to determine patches mode.
+ # FIXME: We should not hardcode BoundAdd here!
+ if sparse_intermediate_bounds:
+ if crown_batch_size < 1e9:
+ warnings.warn('Batched CROWN is not supported in this case')
+ unstable_idx, unstable_size = self.get_unstable_locations(
+ ref_intermediate_lb, ref_intermediate_ub, conv=True)
+ if unstable_size == 0:
+ # Do nothing, no bounds will be computed.
+ reduced_dim = True
+ unstable_idx = []
+ elif (sparse_conv_intermediate_bounds and unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse:
+ # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity.
+ num_channel = node.output_shape[-3]
+ # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1).
+ patches = (
+ torch.eye(num_channel, device=self.device,
+ dtype=list(self.parameters())[0].dtype)).view(
+ num_channel, 1, 1, 1, num_channel, 1, 1)
+ # Expand to (out_c, 1, out_h, out_w, out_c, 1, 1), then select the unstable neurons.
+ patches = patches.expand(-1, 1, node.output_shape[-2], node.output_shape[-1], -1, 1, 1)
+ patches = patches[unstable_idx[0], :, unstable_idx[1], unstable_idx[2]]
+ # Expand with the batch dimension. Final shape (unstable_size, batch_size, out_c, 1, 1).
+ patches = patches.expand(-1, batch_size, -1, -1, -1)
+ newC = Patches(
+ patches, 1, 0, patches.shape, unstable_idx=unstable_idx,
+ output_shape=[batch_size, *node.output_shape[1:]])
+ reduced_dim = True
+ else:
+ unstable_idx = None
+ del ref_intermediate_lb, ref_intermediate_ub
+ if not reduced_dim:
+ num_channel = node.output_shape[-3]
+ # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1).
+ patches = (
+ torch.eye(num_channel, device=self.device,
+ dtype=list(self.parameters())[0].dtype)).view(
+ num_channel, 1, 1, 1, num_channel, 1, 1)
+ # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1).
+ patches = patches.expand(-1, batch_size, node.output_shape[-2], node.output_shape[-1], -1, 1, 1)
+ newC = Patches(patches, 1, 0, patches.shape, output_shape=[batch_size, *node.output_shape[1:]])
+ else:
+ if sparse_intermediate_bounds:
+ unstable_idx, unstable_size = self.get_unstable_locations(
+ ref_intermediate_lb, ref_intermediate_ub)
+ if unstable_size == 0:
+ # Do nothing, no bounds will be computed.
+ reduced_dim = True
+ unstable_idx = []
+ elif unstable_size > crown_batch_size:
+ # Create C in batched CROWN
+ newC = 'eye'
+ reduced_dim = True
+ elif (unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse:
+ newC = torch.zeros([1, unstable_size, dim], device=self.device)
+ # Fill the corresponding elements to 1.0
+ newC[0, torch.arange(unstable_size), unstable_idx] = 1.0
+ newC = newC.expand(batch_size, -1, -1).view(batch_size, unstable_size, *node.output_shape[1:])
+ reduced_dim = True
+ else:
+ unstable_idx = None
+ del ref_intermediate_lb, ref_intermediate_ub
+ if not reduced_dim:
+ if dim > 1000:
+ warnings.warn(f"Creating an identity matrix with size {dim}x{dim} for node {node}. This may indicate poor performance for bound computation. If you see this message on a small network please submit a bug report.", stacklevel=2)
+ newC = torch.eye(dim, device=self.device, dtype=list(self.parameters())[0].dtype) \
+ .unsqueeze(0).expand(batch_size, -1, -1) \
+ .view(batch_size, dim, *node.output_shape[1:])
+
+ return newC, reduced_dim, unstable_idx, unstable_size
+
+
+def restore_sparse_bounds(
+ self, node, unstable_idx, unstable_size, ref_intermediate_lb, ref_intermediate_ub,
+ new_lower=None, new_upper=None):
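+ """Scatter bounds computed only for unstable neurons back into
+ full-shape bounds, keeping the reference bounds for stable neurons."""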
+ batch_size = self.batch_size
+ if unstable_size == 0:
+ # No unstable neurons. Skip the update.
+ node.lower = ref_intermediate_lb.detach().clone()
+ node.upper = ref_intermediate_ub.detach().clone()
+ else:
+ if new_lower is None:
+ new_lower = node.lower
+ if new_upper is None:
+ new_upper = node.upper
+ # If we only calculated unstable neurons, we need to scatter the results back based on reference bounds.
+ if isinstance(unstable_idx, tuple):
+ lower = ref_intermediate_lb.detach().clone()
+ upper = ref_intermediate_ub.detach().clone()
+ # Conv layer with patches, the unstable_idx is a 3-element tuple for 3 indices (C, H,W) of unstable neurons.
+ if len(unstable_idx) == 3:
+ lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_lower
+ upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_upper
+ elif len(unstable_idx) == 4:
+ lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_lower
+ upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_upper
+ else:
+ # Other layers.
+ lower = ref_intermediate_lb.detach().clone().view(batch_size, -1)
+ upper = ref_intermediate_ub.detach().clone().view(batch_size, -1)
+ lower[:, unstable_idx] = new_lower.view(batch_size, -1)
+ upper[:, unstable_idx] = new_upper.view(batch_size, -1)
+ node.lower = lower.view(batch_size, *node.output_shape[1:])
+ node.upper = upper.view(batch_size, *node.output_shape[1:])
+
+
+def get_degrees(node_start, backward_from):
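+ """Count the out-degree of every node that reaches `node_start`, used to
+ schedule the backward propagation in topological order."""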
+ degrees = {}
+ queue = deque([node_start])
+ node_start.bounded = False
+ while len(queue) > 0:
+ l = queue.popleft()
+ backward_from[l.name].append(node_start)
+ for l_pre in l.inputs:
+ degrees[l_pre.name] = degrees.get(l_pre.name, 0) + 1
+ if l_pre.bounded:
+ l_pre.bounded = False
+ queue.append(l_pre)
+ return degrees
+
+
+def preprocess_C(C, node):
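+ """Convert the specification matrix C into the internal (spec, batch)
+ layout and infer batch size, output dimension and output shape."""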
+ if isinstance(C, Patches):
+ if C.unstable_idx is None:
+ # Patches have size (out_c, batch, out_h, out_w, c, h, w).
+ if len(C.shape) == 7:
+ out_c, batch_size, out_h, out_w = C.shape[:4]
+ output_dim = out_c * out_h * out_w
+ else:
+ out_dim, batch_size, out_c, out_h, out_w = C.shape[:5]
+ output_dim = out_dim * out_c * out_h * out_w
+ else:
+ # Patches have size (unstable_size, batch, c, h, w).
+ output_dim, batch_size = C.shape[:2]
+ else:
+ batch_size, output_dim = C.shape[:2]
+
+ # The C matrix specified by the user has shape (batch, spec) but internally we have (spec, batch) format.
+ if not isinstance(C, (eyeC, Patches, OneHotC)):
+ C = C.transpose(0, 1)
+ elif isinstance(C, eyeC):
+ C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:]))
+ elif isinstance(C, OneHotC):
+ C = C._replace(
+ shape=(C.shape[1], C.shape[0], *C.shape[2:]),
+ index=C.index.transpose(0,-1),
+ coeffs=None if C.coeffs is None else C.coeffs.transpose(0,-1))
+
+ if isinstance(C, Patches) and C.unstable_idx is not None:
+ # Sparse patches; the output shape is (unstable_size, ).
+ output_shape = [C.shape[0]]
+ elif prod(node.output_shape[1:]) != output_dim and not isinstance(C, Patches):
+ # For the output node, the shape of the bound follows C
+ # instead of the original output shape
+ #
+ # TODO Maybe don't set node.lower and node.upper in this case?
+ # Currently some codes still depend on node.lower and node.upper
+ output_shape = [-1]
+ else:
+ # Generally, the shape of the bounds match the output shape of the node
+ output_shape = node.output_shape[1:]
+
+ return C, batch_size, output_dim, output_shape
+
+
+def concretize(lb, ub, node, root, batch_size, output_dim, bound_lower=True, bound_upper=True, average_A=False):
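+ """Concretize the accumulated linear coefficients (lA/uA) at the root
+ nodes into scalar bounds, using perturbation concretization for
+ perturbed roots and forward values for unperturbed ones."""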
+
+ for i in range(len(root)):
+ if root[i].lA is None and root[i].uA is None: continue
+ if average_A and isinstance(root[i], BoundParams):
+ lA = root[i].lA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].lA.shape) if bound_lower else None
+ uA = root[i].uA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].uA.shape) if bound_upper else None
+ else:
+ lA, uA = root[i].lA, root[i].uA
+ if not isinstance(root[i].lA, eyeC) and not isinstance(root[i].lA, Patches):
+ lA = root[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None
+ if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].uA, Patches):
+ uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None
+ if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
+ if isinstance(root[i], BoundParams):
+ # add batch_size dim for weights node
+ lb = lb + root[i].perturbation.concretize(
+ root[i].center.unsqueeze(0), lA,
+ sign=-1, aux=root[i].aux) if bound_lower else None
+ ub = ub + root[i].perturbation.concretize(
+ root[i].center.unsqueeze(0), uA,
+ sign=+1, aux=root[i].aux) if bound_upper else None
+ else:
+ lb = lb + root[i].perturbation.concretize(
+ root[i].center, lA, sign=-1, aux=root[i].aux) if bound_lower else None
+ ub = ub + root[i].perturbation.concretize(
+ root[i].center, uA, sign=+1, aux=root[i].aux) if bound_upper else None
+ else:
+ fv = root[i].forward_value
+ if type(root[i]) == BoundInput:
+ # Input node with a batch dimension
+ batch_size_ = batch_size
+ else:
+ # Parameter node without a batch dimension
+ batch_size_ = 1
+
+ if bound_lower:
+ if isinstance(lA, eyeC):
+ lb = lb + fv.view(batch_size_, -1)
+ elif isinstance(lA, Patches):
+ lb = lb + lA.matmul(fv, input_shape=root[0].center.shape)
+ elif type(root[i]) == BoundInput:
+ lb = lb + lA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1)
+ else:
+ lb = lb + lA.matmul(fv.view(-1, 1)).squeeze(-1)
+ else:
+ lb = None
+
+ if bound_upper:
+ if isinstance(uA, eyeC):
+ ub = ub + fv.view(batch_size_, -1)
+ elif isinstance(uA, Patches):
+ ub = ub + uA.matmul(fv, input_shape=root[0].center.shape)
+ elif type(root[i]) == BoundInput:
+ ub = ub + uA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1)
+ else:
+ ub = ub + uA.matmul(fv.view(-1, 1)).squeeze(-1)
+ else:
+ ub = None
+
+ return lb, ub
+
+
+def addA(A1, A2):
+ """ Add two A (each of them is either Tensor or Patches) """
+ if type(A1) == type(A2):
+ return A1 + A2
+ elif type(A1) == Patches:
+ return A1 + A2
+ elif type(A2) == Patches:
+ return A2 + A1
+ else:
+ raise NotImplementedError(f'Unsupported types for A1 ({type(A1)}) and A2 ({type(A2)})')
+
+
+def add_bound(node, node_pre, lA, uA):
+ """Propagate lA and uA to a preceding node."""
+ if lA is not None:
+ if node_pre.lA is None:
+ # First A added to this node.
+ node_pre.zero_lA_mtx = node.zero_backward_coeffs_l
+ node_pre.lA = lA
+ else:
+ node_pre.zero_lA_mtx = node_pre.zero_lA_mtx and node.zero_backward_coeffs_l
+ new_node_lA = addA(node_pre.lA, lA)
+ node_pre.lA = new_node_lA
+ if uA is not None:
+ if node_pre.uA is None:
+ # First A added to this node.
+ node_pre.zero_uA_mtx = node_pre.zero_backward_coeffs_u
+ node_pre.uA = uA
+ else:
+ node_pre.zero_uA_mtx = node_pre.zero_uA_mtx and node.zero_backward_coeffs_u
+ node_pre.uA = addA(node_pre.uA, uA)
+
+
+def get_beta_watch_list(intermediate_constr, all_nodes_before):
+ beta_watch_list = defaultdict(dict)
+ if intermediate_constr is not None:
+ # Intermediate layer betas are handled in two cases.
+ # First, if the beta split is before this node, we don't need to do anything special;
+ # it will be done in BoundRelu.
+ # Second, if the beta split is after this node, we need to modify the A matrix
+ # during bound propagation to reflect beta after this layer.
+ for k in intermediate_constr:
+ if k not in all_nodes_before:
+ # The second case needs special care: we add all such splits in a watch list.
+ # However, after the first occurrence of a layer in the watch list,
+ # beta_watch_list will be deleted and the A matrix from split constraints
+ # has been added and will be propagated to later layers.
+ for kk, vv in intermediate_constr[k].items():
+ beta_watch_list[kk][k] = vv
+ return beta_watch_list
+
+
+def add_constant_node(lb, ub, node):
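+ """Add the bias contribution of a constant (non-perturbed) node to the
+ current bounds."""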
+ new_lb = node.get_bias(node.lA, node.forward_value)
+ new_ub = node.get_bias(node.uA, node.forward_value)
+ if isinstance(lb, Tensor) and isinstance(new_lb, Tensor) and lb.ndim > 0 and lb.ndim != new_lb.ndim:
+ new_lb = new_lb.reshape(lb.shape)
+ if isinstance(ub, Tensor) and isinstance(new_ub, Tensor) and ub.ndim > 0 and ub.ndim != new_ub.ndim:
+ new_ub = new_ub.reshape(ub.shape)
+ lb = lb + new_lb # FIXME (09/16): shape for the bias of BoundConstant.
+ ub = ub + new_ub
+ return lb, ub
+
+
+def save_A_record(node, A_record, A_dict, root, needed_A_dict, lb, ub, unstable_idx):
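+ """Save the A matrices and bias terms accumulated at the root nodes into
+ `A_dict` for the nodes requested in `needed_A_dict`."""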
+ root_A_record = {}
+ for i in range(len(root)):
+ if root[i].lA is None and root[i].uA is None: continue
+ if root[i].name in needed_A_dict:
+ if root[i].lA is not None:
+ if isinstance(root[i].lA, Patches):
+ _lA = root[i].lA
+ else:
+ _lA = root[i].lA.transpose(0, 1).detach()
+ else:
+ _lA = None
+
+ if root[i].uA is not None:
+ if isinstance(root[i].uA, Patches):
+ _uA = root[i].uA
+ else:
+ _uA = root[i].uA.transpose(0, 1).detach()
+ else:
+ _uA = None
+ root_A_record.update({root[i].name: {
+ "lA": _lA,
+ "uA": _uA,
+ # When not used, lb or ub is tensor(0). They have been transposed above.
+ "lbias": lb.detach() if lb.ndim > 1 else None,
+ "ubias": ub.detach() if ub.ndim > 1 else None,
+ "unstable_idx": unstable_idx
+ }})
+ root_A_record.update(A_record) # merge to existing A_record
+ A_dict.update({node.name: root_A_record})
+
+
+def select_unstable_idx(ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size):
+ """When there are too many unstable neurons, only bound those
+ with the loosest reference bounds."""
+ gap = (
+ ref_intermediate_ub[:, unstable_locs]
+ - ref_intermediate_lb[:, unstable_locs]).sum(dim=0)
+ indices = torch.argsort(gap, descending=True)
+ indices_selected = indices[:max_crown_size]
+ indices_selected, _ = torch.sort(indices_selected)
+ print(f'{len(indices_selected)}/{len(indices)} unstable neurons selected for CROWN')
+ return indices_selected
+
+
+def get_unstable_locations(
+ self, ref_intermediate_lb, ref_intermediate_ub, conv=False, channel_only=False):
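+ """Find the indices of unstable neurons (lower bound < 0 < upper bound)
+ from reference bounds, merged over the batch dimension."""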
+ max_crown_size = self.bound_opts.get('max_crown_size', int(1e9))
+ # For conv layer we only check the case where all neurons are active/inactive.
+ unstable_masks = torch.logical_and(ref_intermediate_lb < 0, ref_intermediate_ub > 0)
+ # For simplicity, merge unstable locations over all batch elements.
+ # The merged mask has the neuron shape (no batch dim) and indicates whether
+ # a neuron is unstable in at least one batch element.
+ # TODO: use individual unstable masks per batch element.
+ if channel_only:
+ # Only keep channels with unstable neurons. Used for initializing alpha.
+ unstable_locs = unstable_masks.sum(dim=(0,2,3)).bool()
+ # Shape is consistent with linear layers: a list of unstable neuron channels (no batch dim).
+ unstable_idx = unstable_locs.nonzero().squeeze(1)
+ else:
+ if not conv and unstable_masks.ndim > 2:
+ # Flatten the conv layer shape.
+ unstable_masks = unstable_masks.view(unstable_masks.size(0), -1)
+ ref_intermediate_lb = ref_intermediate_lb.view(ref_intermediate_lb.size(0), -1)
+ ref_intermediate_ub = ref_intermediate_ub.view(ref_intermediate_ub.size(0), -1)
+ unstable_locs = unstable_masks.sum(dim=0).bool()
+ if conv:
+ # Now convert the mask to indices of the unstable neurons.
+ # These are locations (i,j) of unstable neurons.
+ unstable_idx = unstable_locs.nonzero(as_tuple=True)
+ else:
+ unstable_idx = unstable_locs.nonzero().squeeze(1)
+
+ unstable_size = get_unstable_size(unstable_idx)
+ if unstable_size > max_crown_size:
+ indices_selected = select_unstable_idx(
+ ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size)
+ if isinstance(unstable_idx, tuple):
+ unstable_idx = tuple(u[indices_selected] for u in unstable_idx)
+ else:
+ unstable_idx = unstable_idx[indices_selected]
+ unstable_size = get_unstable_size(unstable_idx)
+
+ return unstable_idx, unstable_size
+
+
+def get_alpha_crown_start_nodes(
+ self, node, c=None, share_slopes=False, final_node_name=None):
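+ """Collect (name, spec_shape, unstable_idx) tuples for all start nodes,
+ which determine the shape and sparsity of the alpha variables created
+ for optimizing the bounds of `node`."""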
+ # When use_full_conv_alpha is True, conv layers do not share alpha.
+ sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False)
+ use_full_conv_alpha_thresh = self.bound_opts.get('use_full_conv_alpha_thresh', 512)
+
+ start_nodes = []
+ for nj in self.backward_from[node.name]: # Pre-activation layers.
+ unstable_idx = None
+ use_sparse_conv = None
+ use_full_conv_alpha = self.bound_opts.get('use_full_conv_alpha', False)
+ if (sparse_intermediate_bounds and isinstance(node, BoundRelu)
+ and nj.name != final_node_name and not share_slopes):
+ # Create sparse optimization variables for intermediate neurons.
+ if ((isinstance(nj, BoundLinear) or isinstance(nj, BoundMatMul))
+ and int(os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0):
+ # unstable_idx has shape [neuron_size_of_nj]. Batch dimension is reduced.
+ unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper)
+ elif isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches':
+ if nj.name in node.patch_size:
+ # unstable_idx has shape [channel_size_of_nj]. Batch and spatial dimensions are reduced.
+ unstable_idx, _ = self.get_unstable_locations(
+ nj.lower, nj.upper, channel_only=not use_full_conv_alpha, conv=True)
+ use_sparse_conv = False # alpha is shared among channels. Sparse-spec alpha in hw dimension not used.
+ if use_full_conv_alpha and unstable_idx[0].size(0) > use_full_conv_alpha_thresh:
+ # Too many unstable neurons. Using shared alpha per channel.
+ unstable_idx, _ = self.get_unstable_locations(
+ nj.lower, nj.upper, channel_only=True, conv=True)
+ use_full_conv_alpha = False
+ else:
+ # matrix mode for conv layers.
+ # unstable_idx has shape [c_out * h_out * w_out]. Batch dimension is reduced.
+ unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper)
+ use_sparse_conv = True # alpha is not shared among channels, and is sparse in spec dimension.
+ if nj.name == final_node_name:
+ size_final = self[final_node_name].output_shape[-1] if c is None else c.size(1)
+ start_nodes.append((final_node_name, size_final, None))
+ continue
+ if share_slopes:
+ # all intermediate neurons from the same layer share the same set of slopes.
+ output_shape = 1
+ elif isinstance(node, BoundOptimizableActivation) and node.patch_size and nj.name in node.patch_size:
+ # Patches mode. Use output channel size as the spec size. This still shares some alpha, but better than no sharing.
+ if use_full_conv_alpha:
+ # alphas not shared among channels, so the spec dim shape is c,h,w
+ # The patch size is [out_ch, batch, out_h, out_w, in_ch, H, W]. We use out_ch as the output shape.
+ output_shape = node.patch_size[nj.name][0], node.patch_size[nj.name][2], node.patch_size[nj.name][3]
+ else:
+ # The spec dim is c only, and is shared among h, w.
+ output_shape = node.patch_size[nj.name][0]
+ assert not sparse_intermediate_bounds or use_sparse_conv is False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha.
+ else:
+ # Output is linear layer, or patch converted to matrix.
+ assert not sparse_intermediate_bounds or use_sparse_conv is not False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha.
+ output_shape = nj.lower.shape[1:] # FIXME: for non-relu activations it's still expecting a prod.
+ start_nodes.append((nj.name, output_shape, unstable_idx))
+ return start_nodes
diff --git a/auto_LiRPA/beta_crown.py b/auto_LiRPA/beta_crown.py
new file mode 100644
index 0000000..6be6dc0
--- /dev/null
+++ b/auto_LiRPA/beta_crown.py
@@ -0,0 +1,48 @@
+import torch
+
+
+def beta_bias(self):
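+ """Compute the bias term contributed by the beta split constraints of
+ all ReLU layers."""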
+ batch_size = len(self.relus[-1].split_beta)
+ batch = int(batch_size/2)
+ bias = torch.zeros((batch_size, 1), device=self.device)
+ for m in self.relus:
+ if not m.used or not m.perturbed:
+ continue
+ if m.split_beta_used:
+ bias[:batch] = bias[:batch] + m.split_bias*m.split_beta[:batch]*m.split_c[:batch]
+ bias[batch:] = bias[batch:] + m.split_bias*m.split_beta[batch:]*m.split_c[batch:]
+ if m.history_beta_used:
+ bias = bias + (m.new_history_bias*m.new_history_beta*m.new_history_c).sum(1, keepdim=True)
+ # No single node split here, because single node splits do not have bias.
+ return bias
+
+
+def print_optimized_beta(self, relus, intermediate_beta_enabled=False):
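+ """Print the optimized beta values (and split biases) of each ReLU layer
+ for debugging."""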
+ masked_betas = []
+ for model in relus:
+ masked_betas.append(model.masked_beta)
+ if model.history_beta_used:
+ print(f"{model.name} history beta", model.new_history_beta.squeeze())
+ if model.split_beta_used:
+ print(f"{model.name} split beta:", model.split_beta.view(-1))
+ print(f"{model.name} bias:", model.split_bias)
+
+
+def save_best_intermediate_betas(self, relus, idx):
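+ """Save the best intermediate-layer betas (for history, split and single
+ node splits) for the batch elements selected by `idx`."""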
+ for layer in relus:
+ # The history splits and the current split are handled separately.
+ if layer.history_beta_used:
+ # Each key in history_intermediate_betas for this layer is a dictionary, with all other pre-relu layers' names.
+ for k, v in layer.history_intermediate_betas.items():
+ # This is a tensor with shape (batch, *intermediate_layer_shape, number_of_beta)
+ self.best_intermediate_betas[layer.name]['history'][k]["lb"][idx] = v["lb"][idx]
+ self.best_intermediate_betas[layer.name]['history'][k]["ub"][idx] = v["ub"][idx]
+ if layer.split_beta_used:
+ for k, v in layer.split_intermediate_betas.items():
+ # This is a tensor with shape (batch, *intermediate_layer_shape, 1)
+ self.best_intermediate_betas[layer.name]['split'][k]["lb"][idx] = v["lb"][idx]
+ self.best_intermediate_betas[layer.name]['split'][k]["ub"][idx] = v["ub"][idx]
+ if layer.single_beta_used:
+ for k, v in layer.single_intermediate_betas.items():
+ self.best_intermediate_betas[layer.name]['single'][k]["lb"][idx] = v["lb"][idx]
+ self.best_intermediate_betas[layer.name]['single'][k]["ub"][idx] = v["ub"][idx]
\ No newline at end of file
diff --git a/auto_LiRPA/bound_general.py b/auto_LiRPA/bound_general.py
index 27ceeae..d02ebe0 100644
--- a/auto_LiRPA/bound_general.py
+++ b/auto_LiRPA/bound_general.py
@@ -1,11 +1,9 @@
-import time
-import os
import numpy as np
-from collections import OrderedDict, deque, defaultdict
+import warnings
+from collections import OrderedDict, deque
import torch
-import torch.optim as optim
-from torch.nn import DataParallel, Parameter, parameter
+from torch.nn import Parameter
from .bound_op_map import bound_op_map
from .bound_ops import *
@@ -13,12 +11,11 @@
from .parse_graph import parse_module
from .perturbations import *
from .utils import *
-from .adam_element_lr import AdamElementLR
+from .patches import Patches
-import warnings
-warnings.simplefilter("once")
+warnings.simplefilter('once')
class BoundedModule(nn.Module):
"""Bounded module with support for automatically computing bounds.
@@ -26,93 +23,203 @@ class BoundedModule(nn.Module):
Args:
model (nn.Module): The original model to be wrapped by BoundedModule.
- global_input (tuple): A dummy input to the original model. The shape of
- the dummy input should be consistent with the actual input to the model
+ global_input (tuple): A dummy input to the original model. The shape of
+ the dummy input should be consistent with the actual input to the model
except for the batch dimension.
- bound_opts (dict): Options for bounds. See
+ bound_opts (dict): Options for bounds. See
`Bound Options `_.
- device (str or torch.device): Device of the bounded module.
- If 'auto', the device will be automatically inferred from the device of
+ device (str or torch.device): Device of the bounded module.
+ If 'auto', the device will be automatically inferred from the device of
parameters in the original model or the dummy input.
- custom_ops (dict): A dictionary of custom operators.
- The dictionary maps operator names to their corresponding bound classes
+ custom_ops (dict): A dictionary of custom operators.
+ The dictionary maps operator names to their corresponding bound classes
(subclasses of `Bound`).
"""
- def __init__(self, model, global_input, bound_opts=None, auto_batch_dim=True, device='auto',
- verbose=False, custom_ops={}):
- super(BoundedModule, self).__init__()
+ def __init__(self, model, global_input, bound_opts=None,
+ device='auto', verbose=False, custom_ops=None):
+ super().__init__()
if isinstance(model, BoundedModule):
for key in model.__dict__.keys():
setattr(self, key, getattr(model, key))
return
+
+ self.global_input = global_input
+ self.ori_training = model.training
+ self.check_incompatible_nodes(model)
+
if bound_opts is None:
bound_opts = {}
# Default options.
- default_bound_opts = {'ibp_relative': False, 'conv_mode': 'patches', 'sparse_intermediate_bounds': True, 'sparse_conv_intermediate_bounds': True}
+ default_bound_opts = {
+ 'conv_mode': 'patches',
+ 'sparse_intermediate_bounds': True,
+ 'sparse_conv_intermediate_bounds': True,
+ 'sparse_intermediate_bounds_with_ibp': True,
+ 'sparse_features_alpha': True,
+ 'sparse_spec_alpha': True,
+ 'minimum_sparsity': 0.9,
+ 'enable_opt_interm_bounds': False,
+ 'crown_batch_size': np.inf,
+ 'forward_refinement': False,
+ 'dynamic_forward': False,
+ 'forward_max_dim': int(1e9),
+ # Do not share alpha for conv layers.
+ 'use_full_conv_alpha': True,
+ # Threshold for number of unstable neurons for each layer to disable
+ # use_full_conv_alpha.
+ 'use_full_conv_alpha_thresh': 512,
+ 'verbosity': 1 if verbose else 0,
+ }
default_bound_opts.update(bound_opts)
self.bound_opts = default_bound_opts
self.verbose = verbose
- self.custom_ops = custom_ops
- self.auto_batch_dim = auto_batch_dim
+ self.custom_ops = custom_ops if custom_ops is not None else {}
if device == 'auto':
try:
self.device = next(model.parameters()).device
- except StopIteration: # Model has no parameters. We use the device of input tensor.
+ except StopIteration:
+ # Model has no parameters. We use the device of input tensor.
self.device = global_input.device
else:
self.device = device
- self.global_input = global_input
- self.ibp_relative = self.bound_opts.get('ibp_relative', False)
self.conv_mode = self.bound_opts.get('conv_mode', 'patches')
- if auto_batch_dim:
- # logger.warning('Using automatic batch dimension inferring, which may not be correct')
- self.init_batch_size = -1
+ # Cached IBP results which may be reused
+ self.ibp_lower, self.ibp_upper = None, None
+
+ self.optimizable_activations = []
+ self.relus = [] # save relu layers for convenience
state_dict_copy = copy.deepcopy(model.state_dict())
object.__setattr__(self, 'ori_state_dict', state_dict_copy)
model.to(self.device)
- self.final_shape = model(*unpack_inputs(global_input, device=self.device)).shape
+ self.final_shape = model(
+ *unpack_inputs(global_input, device=self.device)).shape
self.bound_opts.update({'final_shape': self.final_shape})
self._convert(model, global_input)
self._mark_perturbed_nodes()
# set the default values here
- optimize_bound_args = {'ob_iteration': 20, 'ob_beta': False, 'ob_alpha': True, 'ob_alpha_share_slopes': False,
- 'ob_opt_coeffs': False, 'ob_opt_bias': False,
- 'ob_optimizer': 'adam', 'ob_verbose': 0,
- 'ob_keep_best': True, 'ob_update_by_layer': True, 'ob_lr': 0.5,
- 'ob_lr_beta': 0.05, 'ob_init': True,
- 'ob_single_node_split': True, 'ob_lr_intermediate_beta': 0.1,
- 'ob_lr_coeffs': 0.01, 'ob_intermediate_beta': False, 'ob_intermediate_refinement_layers': [-1],
- 'ob_loss_reduction_func': reduction_sum,
- 'ob_stop_criterion_func': lambda x: False,
- 'ob_input_grad': False,
- 'ob_lr_decay': 0.98 }
+ optimize_bound_args = {
+ 'enable_alpha_crown': True, # Enable optimization of alpha.
+ 'enable_beta_crown': False, # Enable beta split constraint.
+ 'iteration': 20, # Number of alpha/beta optimization iterations.
+ # Share some alpha variables to save memory at the cost of slightly
+ # looser bounds.
+ 'use_shared_alpha': False,
+ # Optimize coeffs during intermediate_refinement.
+ 'opt_coeffs': False,
+ # Optimize constraint bias during intermediate_refinement.
+ 'opt_bias': False,
+ # Optimizer used for alpha and beta optimization.
+ 'optimizer': 'adam',
+ # Save best results of alpha/beta/bounds during optimization.
+ 'keep_best': True,
+ # Only optimize bounds of last layer during alpha/beta CROWN.
+ 'fix_intermediate_layer_bounds': True,
+ # Learning rate for the optimizable parameter alpha in alpha-CROWN.
+ 'lr_alpha': 0.5,
+ # Learning rate for the optimizable parameter beta in beta-CROWN.
+ 'lr_beta': 0.05,
+ 'lr_cut_beta': 5e-3, # Learning rate for optimizing cut betas.
+ # Initial alpha variables by calling CROWN once.
+ 'init_alpha': True,
+ # Only split single nodes in branch and bound.
+ 'single_node_split': True,
+ # Learning rate for intermediate layer beta for refinement.
+ 'lr_intermediate_beta': 0.1,
+ 'lr_coeffs': 0.01, # Learning rate for coeffs for refinement
+            # Use beta variables on intermediate layer bounds.
+            'intermediate_beta': False,
+ # Layers to be refined, separated by commas.
+ # -1 means preactivation before last relu.
+ 'intermediate_refinement_layers': [-1],
+ # When batch size is not 1, this reduction function is applied to
+ # reduce the bounds into a scalar.
+ 'loss_reduction_func': reduction_sum,
+ # Criteria function of early stop.
+ 'stop_criterion_func': lambda x: False,
+ # Learning rate decay factor during bounds optimization.
+ 'lr_decay': 0.98,
+            # Number of iterations with no improvement after which early
+            # stopping is considered.
+            'early_stop_patience': 10,
+            # Start saving the optimized best bounds once
+            # current_iteration > int(iteration * start_save_best).
+            'start_save_best': 0.5,
+ # Use double fp (float64) at the last iteration in alpha/beta CROWN.
+ 'use_float64_in_last_iteration': False,
+ # Prune verified domain within iteration.
+ 'pruning_in_iteration': False,
+ # Percentage of the minimum domains that can apply pruning.
+ 'pruning_in_iteration_threshold': 0.2,
+            # For specifications that output multiple bounds for one
+            # property, we use this function to prune them.
+            'multi_spec_keep_func': lambda x: True
+ }
+
# change by bound_opts
- optimize_bound_args.update(self.bound_opts.get('optimize_bound_args', {}))
+ optimize_bound_args.update(
+ self.bound_opts.get('optimize_bound_args', {}))
self.bound_opts.update({'optimize_bound_args': optimize_bound_args})
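+        # Usage sketch (`model` and `dummy_input` are hypothetical): these
+        # defaults can be overridden per instance via a nested dict, e.g.
+        #   BoundedModule(model, dummy_input, bound_opts={
+        #       'optimize_bound_args': {'iteration': 50, 'lr_alpha': 0.1}})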
self.next_split_hint = [] # Split hints, used in beta optimization.
- self.relus = [] # save relu layers for convenience
- for l in self._modules.values():
- if isinstance(l, BoundRelu):
- self.relus.append(l)
- self.optimizable_activations = []
- for l in self._modules.values():
- if isinstance(l, BoundOptimizableActivation):
- self.optimizable_activations.append(l)
-
- # Beta values for all intermediate bounds. Set to None (not used) by default.
+ # Beta values for all intermediate bounds.
+ # Set to None (not used) by default.
self.best_intermediate_betas = None
# Initialization value for intermediate betas.
self.init_intermediate_betas = None
+        # Whether cut constraints are used.
+        self.cut_used = False
+        # A placeholder for the cut timestamp, which is a non-positive int.
+        self.cut_timestamp = -1
+
+ # List of operators. When we are computing intermediate bounds for these
+ # ops, we simply use IBP to propagate bounds from its input nodes,
+ # instead of CROWN.
+ self.ibp_intermediate = [BoundRelu, BoundNeg, BoundTranspose]
+
+ # a placeholder to save the latest samplewise mask for
+ # pruning-in-iteration optimization
+ self.last_update_preserve_mask = None
+
+ @property
+ def perturbed_optimizable_activations(self):
+ return [n for n in self.optimizable_activations if n.perturbed]
+
+
+ def check_incompatible_nodes(self, model):
+ """Check whether the model has incompatible nodes that the conversion
+ may be inaccurate"""
+ node_types = [type(m) for m in list(model.modules())]
+
+ if (torch.nn.Dropout in node_types
+ and torch.nn.BatchNorm1d in node_types
+ and self.global_input.shape[0] == 1):
+ print('We cannot support torch.nn.Dropout and torch.nn.BatchNorm1d '
+ 'at the same time!')
+            print('Please use another dummy input with batch size larger '
+                  'than 1 and set the model to train() mode.')
+ return
+
+ if not self.ori_training and torch.nn.Dropout in node_types:
+ print('Dropout operation CANNOT be parsed during conversion when '
+ 'the model is in eval() mode!')
+ print('Set model to train() mode!')
+ self.ori_training = True
+
+ if self.ori_training and torch.nn.BatchNorm1d in node_types:
+ print('BatchNorm1d may raise error during conversion when the model'
+ ' is in train() mode!')
+ print('Set model to eval() mode!')
+ self.ori_training = False
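+
+    # Illustrative trigger for the checks above (hypothetical model): a
+    # batch-1 dummy input combined with both Dropout and BatchNorm1d, e.g.
+    #   m = nn.Sequential(nn.Linear(4, 4), nn.Dropout(), nn.BatchNorm1d(4))
+    #   BoundedModule(m, torch.empty(1, 4))  # prints the warnings above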
- """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it."""
def non_deter_wrapper(self, op, *args, **kwargs):
+ """Some operations are non-deterministic and deterministic mode will
+ fail. So we temporary disable it."""
if self.bound_opts.get('deterministic', False):
torch.use_deterministic_algorithms(False)
ret = op(*args, **kwargs)
@@ -128,22 +235,38 @@ def non_deter_index_select(self, *args, **kwargs):
def set_bound_opts(self, new_opts):
for k, v in new_opts.items():
- assert v is not dict, 'only support change optimize_bound_args'
- self.bound_opts[k].update(v)
+ # assert v is not dict, 'only support change optimize_bound_args'
+ if type(v) == dict:
+ self.bound_opts[k].update(v)
+ else:
+ self.bound_opts[k] = v
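+
+    # Usage sketch: dict-valued options are merged into the existing entry,
+    # while other values replace it, e.g.
+    #   model.set_bound_opts({'optimize_bound_args': {'iteration': 100}})
+    #   model.set_bound_opts({'conv_mode': 'matrix'})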
- def __call__(self, *input, **kwargs):
+ @staticmethod
+ def _get_A_norm(A):
+ if not isinstance(A, (list, tuple)):
+ A = (A, )
+ norms = []
+ for aa in A:
+ if aa is not None:
+ if isinstance(aa, Patches):
+ aa = aa.patches
+ norms.append(aa.abs().sum().item())
+ else:
+ norms.append(None)
+ return norms
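+
+    # Worked micro-example: for a dense A matrix the returned value is the
+    # sum of absolute entries, e.g. _get_A_norm(torch.ones(2, 3)) == [6.0];
+    # for a Patches object the same sum is taken over its .patches tensor.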
- if "method_opt" in kwargs:
- opt = kwargs["method_opt"]
- kwargs.pop("method_opt")
+ def __call__(self, *input, **kwargs):
+ if 'method_opt' in kwargs:
+ opt = kwargs['method_opt']
+ kwargs.pop('method_opt')
else:
- opt = "forward"
+ opt = 'forward'
for kwarg in [
'disable_multi_gpu', 'no_replicas', 'get_property',
'node_class', 'att_name']:
if kwarg in kwargs:
kwargs.pop(kwarg)
- if opt == "compute_bounds":
+ if opt == 'compute_bounds':
return self.compute_bounds(**kwargs)
else:
return self.forward(*input, **kwargs)
@@ -160,28 +283,29 @@ def register_parameter(self, name, param):
"""
if '_parameters' not in self.__dict__:
raise AttributeError(
- "cannot assign parameter before Module.__init__() call")
+ 'cannot assign parameter before Module.__init__() call')
elif not isinstance(name, torch._six.string_classes):
- raise TypeError("parameter name should be a string. "
- "Got {}".format(torch.typename(name)))
+ raise TypeError('parameter name should be a string. '
+ f'Got {torch.typename(name)}')
elif name == '':
- raise KeyError("parameter name can't be empty string \"\"")
+ raise KeyError('parameter name can\'t be empty string')
elif hasattr(self, name) and name not in self._parameters:
- raise KeyError("attribute '{}' already exists".format(name))
+ raise KeyError(f'attribute "{name}" already exists')
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
- raise TypeError("cannot assign '{}' object to parameter '{}' "
- "(torch.nn.Parameter or None required)"
- .format(torch.typename(param), name))
+ raise TypeError(
+ f'cannot assign "{torch.typename(param)}" object to '
+ f'parameter "{name}" '
+ '(torch.nn.Parameter or None required)')
elif param.grad_fn:
raise ValueError(
- "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
- "parameters must be created explicitly. To express '{0}' "
- "as a function of another Tensor, compute the value in "
- "the forward() method.".format(name))
+ f'Cannot assign non-leaf Tensor to parameter "{name}". Model '
+            f'parameters must be created explicitly. To express "{name}" '
+ 'as a function of another Tensor, compute the value in '
+ 'the forward() method.')
else:
self._parameters[name] = param
@@ -191,12 +315,13 @@ def load_state_dict(self, state_dict, strict=False):
for k, v in state_dict.items():
if k in self.node_name_map:
new_dict[self.node_name_map[k]] = v
- return super(BoundedModule, self).load_state_dict(new_dict, strict=strict)
+ return super().load_state_dict(new_dict, strict=strict)
def _named_members(self, get_members_fn, prefix='', recurse=True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
- modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+ modules = self.named_modules(prefix=prefix) if recurse else [
+ (prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
@@ -219,75 +344,90 @@ def eval(self):
for node in self._modules.values():
node.eval()
- def forward(self, *x, final_node_name=None, clear_forward_only=False):
+ def to(self, *args, **kwargs):
+        # Move and/or cast cached bound attributes, in addition to what
+        # PyTorch already handles by default.
+ for node in self._modules.values():
+ for attr in ['lower', 'upper', 'forward_value', 'd', 'lA',]:
+ if hasattr(node, attr):
+ this_attr = getattr(node, attr)
+ if isinstance(this_attr, torch.Tensor):
+ this_attr = this_attr.to(*args, **kwargs)
+ setattr(node, attr, this_attr)
+
+ if hasattr(node, 'interval'):
+ # construct new interval
+ this_attr = getattr(node, 'interval')
+ setattr(node, 'interval', (this_attr[0].to(
+ *args, **kwargs), this_attr[1].to(*args, **kwargs)))
+
+ return super().to(*args, **kwargs)
+
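+    # Usage sketch: unlike the default nn.Module.to(), this override also
+    # moves cached bound tensors, e.g.
+    #   model = model.to('cuda')         # moves weights and cached bounds
+    #   model = model.to(torch.float64)  # dtype casts are handled likewise
+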
+ def __getitem__(self, name):
+ return self._modules[name]
+
+ def final_node(self):
+ return self[self.final_name]
+
+ def get_forward_value(self, node):
+ """ Recursively get `forward_value` for `node` and its parent nodes"""
+ if getattr(node, 'forward_value', None) is not None:
+ return node.forward_value
+ inputs = [self.get_forward_value(inp) for inp in node.inputs]
+ for inp in node.inputs:
+ node.from_input = node.from_input or inp.from_input
+ node.input_shape = inputs[0].shape if len(inputs) > 0 else None
+ fv = node.forward(*inputs)
+ if isinstance(fv, (torch.Size, tuple)):
+ fv = torch.tensor(fv, device=self.device)
+ node.forward_value = fv
+ node.output_shape = fv.shape
+ # In most cases, the batch dimension is just the first dimension
+ # if the node depends on input. Otherwise if the node doesn't
+ # depend on input, there is no batch dimension.
+ node.batch_dim = 0 if node.from_input else -1
+ # Unperturbed node but it is not a root node.
+ # Save forward_value to value. (Can be used in forward bounds.)
+ if not node.from_input and len(node.inputs) > 0:
+ node.value = node.forward_value
+ return fv
+
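+    # Usage sketch (node names are model-dependent): once inputs are set, the
+    # concrete value of any internal node can be fetched on demand, e.g.
+    #   value = self.get_forward_value(self[node_name])
+    # where node_name is a hypothetical key in self._modules.
+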
+ def forward(self, *x, final_node_name=None, clear_forward_only=False,
+ reset_perturbed_nodes=True):
r"""Standard forward computation for the network.
Args:
x (tuple or None): Input to the model.
- final_node_name (str, optional): The name of the final node in the model. The value
- on the corresponding node will be returned.
+ final_node_name (str, optional): The name of the final node in the
+ model. The value on the corresponding node will be returned.
+
+            clear_forward_only (bool, default `False`): Whether only the
+                standard forward values stored on the nodes should be cleared.
+                If `True`, only these forward values are cleared; otherwise,
+                bound information on the nodes is cleared as well.
- clear_forward_only (bool, default `False`): Whether only standard forward values stored
- on the nodes should be cleared. If `True`, only standard forward values stored on the
- nodes will be cleared. Otherwise, bound information on the nodes will also be cleared.
+            reset_perturbed_nodes (bool, default `True`): Mark all nodes
+                perturbed by input perturbations. When set to `True`, the
+                `.perturbed` properties of intermediate nodes are reset before
+                being marked again, so manually set values may be cleared.
Returns:
- output: The output of the model, or if `final_node_name` is not `None`, return the
- value on the corresponding node instead.
+ output: The output of the model, or if `final_node_name` is not
+ `None`, return the value on the corresponding node instead.
"""
-
- self._set_input(*x, clear_forward_only=clear_forward_only)
-
- degree_in = {}
- queue = deque()
- for key in self._modules.keys():
- l = self._modules[key]
- degree_in[l.name] = len(l.input_name)
- if degree_in[l.name] == 0:
- queue.append(l)
- forward_values = {}
-
- final_output = None
- while len(queue) > 0:
- l = queue.popleft()
- inp = [forward_values[l_pre] for l_pre in l.input_name]
- for l_pre in l.inputs:
- l.from_input = l.from_input or l_pre.from_input
- fv = l.forward(*inp)
- if isinstance(fv, torch.Size) or isinstance(fv, tuple):
- fv = torch.tensor(fv, device=self.device)
- object.__setattr__(l, 'forward_value', fv)
- # infer batch dimension
- if not hasattr(l, 'batch_dim'):
- inp_batch_dim = [l_pre.batch_dim for l_pre in l.inputs]
- try:
- l.batch_dim = l.infer_batch_dim(self.init_batch_size, *inp_batch_dim)
- except:
- raise Exception(
- 'Fail to infer the batch dimension of ({})[{}]: forward_value shape {}, input batch dimensions {}'.format(
- l, l.name, l.forward_value.shape, inp_batch_dim))
- forward_values[l.name] = l.forward_value
-
- # Unperturbed node but it is not a root node. Save forward_value to value.
- # (Can be used in forward bounds.)
- if not l.from_input and len(l.inputs) > 0:
- l.value = l.forward_value
-
- for l_next in l.output_name:
- degree_in[l_next] -= 1
- if degree_in[l_next] == 0: # all inputs of this node have already set
- queue.append(self._modules[l_next])
-
+ self._set_input(*x, clear_forward_only=clear_forward_only,
+ reset_perturbed_nodes=reset_perturbed_nodes)
if final_node_name:
- return forward_values[final_node_name]
+ return self.get_forward_value(self[final_node_name])
else:
- out = deque([forward_values[n] for n in self.output_name])
+ out = deque([self.get_forward_value(self[n])
+ for n in self.output_name])
def _fill_template(template):
if template is None:
return out.popleft()
- elif isinstance(template, list) or isinstance(template, tuple):
+ elif isinstance(template, (list, tuple)):
res = []
for t in template:
res.append(_fill_template(t))
@@ -302,84 +442,111 @@ def _fill_template(template):
return _fill_template(self.output_template)
- """Mark the graph nodes and determine which nodes need perturbation."""
def _mark_perturbed_nodes(self):
+ """Mark the graph nodes and determine which nodes need perturbation."""
degree_in = {}
queue = deque()
# Initially the queue contains all "root" nodes.
for key in self._modules.keys():
- l = self._modules[key]
- degree_in[l.name] = len(l.input_name)
+ l = self[key]
+ degree_in[l.name] = len(l.inputs)
if degree_in[l.name] == 0:
queue.append(l) # in_degree ==0 -> root node
while len(queue) > 0:
l = queue.popleft()
- # Obtain all output node, and add the output nodes to the queue if all its input nodes have been visited.
- # the initial "perturbed" property is set in BoundInput or BoundParams object, depending on ptb.
+            # Obtain all output nodes, and add an output node to the queue
+            # once all of its input nodes have been visited.
+            # The initial "perturbed" property is set in the BoundInput or
+            # BoundParams object, depending on ptb.
for name_next in l.output_name:
- node_next = self._modules[name_next]
+ node_next = self[name_next]
if isinstance(l, BoundShape):
- # Some nodes like Shape, even connected, do not really propagate bounds.
+ # Some nodes like Shape, even connected,
+ # do not really propagate bounds.
# TODO: make this a property of node?
pass
else:
- # The next node is perturbed if it is already perturbed, or this node is perturbed.
+ # The next node is perturbed if it is already perturbed,
+ # or this node is perturbed.
node_next.perturbed = node_next.perturbed or l.perturbed
degree_in[name_next] -= 1
- if degree_in[name_next] == 0: # all inputs of this node have been visited, now put it in queue.
+ # all inputs of this node have been visited,
+ # now put it in queue.
+ if degree_in[name_next] == 0:
queue.append(node_next)
return
- def _clear_and_set_new(self, new_interval, clear_forward_only=False):
+ def _clear_and_set_new(
+ self, intermediate_layer_bounds, clear_forward_only=False,
+ reset_perturbed_nodes=True):
for l in self._modules.values():
if hasattr(l, 'linear'):
if isinstance(l.linear, tuple):
for item in l.linear:
- del (item)
+ del item
delattr(l, 'linear')
+ if hasattr(l, 'patch_size'):
+ l.patch_size = {}
+
if clear_forward_only:
if hasattr(l, 'forward_value'):
- delattr(l, 'forward_value')
+ delattr(l, 'forward_value')
else:
- for attr in ['lower', 'upper', 'interval', 'forward_value', 'd', 'lA', 'lower_d']:
+ for attr in [
+ 'lower', 'upper', 'interval', 'forward_value', 'd',
+ 'lA', 'lower_d']:
if hasattr(l, attr):
delattr(l, attr)
- for attr in ['zero_backward_coeffs_l', 'zero_backward_coeffs_u', 'zero_lA_mtx', 'zero_uA_mtx']:
+ for attr in [
+ 'zero_backward_coeffs_l', 'zero_backward_coeffs_u',
+ 'zero_lA_mtx', 'zero_uA_mtx']:
setattr(l, attr, False)
# Given an interval here to make IBP/CROWN start from this node
- if new_interval is not None and l.name in new_interval.keys():
- l.interval = tuple(new_interval[l.name][:2])
- l.lower = new_interval[l.name][0]
- l.upper = new_interval[l.name][1]
+ if (intermediate_layer_bounds is not None
+ and l.name in intermediate_layer_bounds.keys()):
+ l.interval = tuple(intermediate_layer_bounds[l.name][:2])
+ l.lower = intermediate_layer_bounds[l.name][0]
+ l.upper = intermediate_layer_bounds[l.name][1]
+ if l.lower is not None:
+ l.lower = l.lower.detach().requires_grad_(False)
+ if l.upper is not None:
+ l.upper = l.upper.detach().requires_grad_(False)
# Mark all nodes as non-perturbed except for weights.
- if not hasattr(l, 'perturbation') or l.perturbation is None:
- l.perturbed = False
-
- def _set_input(self, *x, new_interval=None, clear_forward_only=False):
- self._clear_and_set_new(new_interval=new_interval, clear_forward_only=clear_forward_only)
+ if reset_perturbed_nodes:
+ if not hasattr(l, 'perturbation') or l.perturbation is None:
+ l.perturbed = False
+
+ # Clear operator-specific attributes
+ l.clear()
+
+ def _set_input(
+ self, *x, intermediate_layer_bounds=None,
+ clear_forward_only=False, reset_perturbed_nodes=True):
+ self._clear_and_set_new(
+ intermediate_layer_bounds=intermediate_layer_bounds,
+ clear_forward_only=clear_forward_only,
+ reset_perturbed_nodes=reset_perturbed_nodes)
inputs_unpacked = unpack_inputs(x)
for name, index in zip(self.input_name, self.input_index):
- node = self._modules[name]
+ if index is None:
+ continue
+ node = self[name]
node.value = inputs_unpacked[index]
if isinstance(node.value, (BoundedTensor, BoundedParameter)):
node.perturbation = node.value.ptb
else:
node.perturbation = None
# Mark all perturbed nodes.
- self._mark_perturbed_nodes()
- if self.init_batch_size == -1:
- # Automatic batch dimension inferring: get the batch size from
- # the first dimension of the first input tensor.
- self.init_batch_size = inputs_unpacked[0].shape[0]
+ if reset_perturbed_nodes:
+ self._mark_perturbed_nodes()
def _get_node_input(self, nodesOP, nodesIn, node):
ret = []
ori_names = []
for i in range(len(node.inputs)):
- found = False
for op in nodesOP:
if op.name == node.inputs[i]:
ret.append(op.bound_node)
@@ -392,84 +559,108 @@ def _get_node_input(self, nodesOP, nodesIn, node):
ori_names.append(io.ori_name)
break
if len(ret) <= i:
- raise ValueError('cannot find inputs of node: {}'.format(node.name))
- return ret, ori_names
+ raise ValueError(f'cannot find inputs of node: {node.name}')
+ return ret
- # move all tensors in the object to a specified device
- def _to(self, obj, device):
- if isinstance(obj, torch.Tensor):
- return obj.to(device)
+ def _to(self, obj, dest, inplace=False):
+ """ Move all tensors in the object to a specified dest
+ (device or dtype). The inplace=True option is available for dict."""
+ if obj is None:
+ return obj
+ elif isinstance(obj, torch.Tensor):
+ return obj.to(dest)
+ elif isinstance(obj, Patches):
+ return obj.patches.to(dest)
elif isinstance(obj, tuple):
- return tuple([self._to(item, device) for item in obj])
+ return tuple([self._to(item, dest) for item in obj])
elif isinstance(obj, list):
- return list([self._to(item, device) for item in obj])
+ return list([self._to(item, dest) for item in obj])
elif isinstance(obj, dict):
- res = {}
- for key in obj:
- res[key] = self._to(obj[key], device)
- return res
+ if inplace:
+ for k, v in obj.items():
+ obj[k] = self._to(v, dest, inplace=True)
+ return obj
+ else:
+ return {k: self._to(v, dest) for k, v in obj.items()}
else:
raise NotImplementedError(type(obj))
def _convert_nodes(self, model, global_input):
+ r"""
+ Returns:
+ nodesOP (list): List of operator nodes
+ nodesIn (list): List of input nodes
+ nodesOut (list): List of output nodes
+ template (object): Template to specify the output format
+ """
global_input_cpu = self._to(global_input, 'cpu')
- model.train()
+ if self.ori_training:
+ model.train()
+ else:
+ model.eval()
model.to('cpu')
- nodesOP, nodesIn, nodesOut, template = parse_module(model, global_input_cpu)
+ nodesOP, nodesIn, nodesOut, template = parse_module(
+ model, global_input_cpu)
model.to(self.device)
for i in range(0, len(nodesIn)):
if nodesIn[i].param is not None:
- nodesIn[i] = nodesIn[i]._replace(param=nodesIn[i].param.to(self.device))
+ nodesIn[i] = nodesIn[i]._replace(
+ param=nodesIn[i].param.to(self.device))
global_input_unpacked = unpack_inputs(global_input)
# Convert input nodes and parameters.
for i, n in enumerate(nodesIn):
if n.input_index is not None:
nodesIn[i] = nodesIn[i]._replace(bound_node=BoundInput(
- nodesIn[i].inputs, nodesIn[i].name, nodesIn[i].ori_name,
+ ori_name=nodesIn[i].ori_name,
value=global_input_unpacked[nodesIn[i].input_index],
- perturbation=nodesIn[i].perturbation))
+ perturbation=nodesIn[i].perturbation,
+ input_index=n.input_index))
else:
- bound_class = BoundParams if isinstance(nodesIn[i].param, nn.Parameter) else BoundBuffers
+ bound_class = BoundParams if isinstance(
+ nodesIn[i].param, nn.Parameter) else BoundBuffers
nodesIn[i] = nodesIn[i]._replace(bound_node=bound_class(
- nodesIn[i].inputs, nodesIn[i].name, nodesIn[i].ori_name,
- value=nodesIn[i].param, perturbation=nodesIn[i].perturbation))
+ ori_name=nodesIn[i].ori_name,
+ value=nodesIn[i].param,
+ perturbation=nodesIn[i].perturbation))
unsupported_ops = []
# Convert other operation nodes.
for n in range(len(nodesOP)):
attr = nodesOP[n].attr
- inputs, ori_names = self._get_node_input(nodesOP, nodesIn, nodesOP[n])
-
+ inputs = self._get_node_input(nodesOP, nodesIn, nodesOP[n])
try:
if nodesOP[n].op in self.custom_ops:
op = self.custom_ops[nodesOP[n].op]
elif nodesOP[n].op in bound_op_map:
op = bound_op_map[nodesOP[n].op]
elif nodesOP[n].op.startswith('aten::ATen'):
- op = eval('BoundATen{}'.format(attr['operator'].capitalize()))
+ op = globals()[f'BoundATen{attr["operator"].capitalize()}']
elif nodesOP[n].op.startswith('onnx::'):
- op = eval('Bound{}'.format(nodesOP[n].op[6:]))
+ op = globals()[f'Bound{nodesOP[n].op[6:]}']
else:
raise KeyError
except (NameError, KeyError):
unsupported_ops.append(nodesOP[n])
- logger.error('The node has an unsupported operation: {}'.format(nodesOP[n]))
+ logger.error('The node has an unsupported operation: %s',
+ nodesOP[n])
continue
- if nodesOP[n].op == 'onnx::BatchNormalization':
- # BatchNormalization node needs model.training flag to set running mean and vars
- # set training=False to avoid wrongly updating running mean/vars during bound wrapper
- nodesOP[n] = nodesOP[n]._replace(
- bound_node=op(
- nodesOP[n].inputs, nodesOP[n].name, None, attr,
- inputs, nodesOP[n].output_index, self.bound_opts, self.device, False))
+ attr['device'] = self.device
+
+ # FIXME generalize
+ if (nodesOP[n].op == 'onnx::BatchNormalization'
+ or getattr(op, 'TRAINING_FLAG', False)):
+                # A BatchNormalization node needs the model.training flag to
+                # set its running mean and variance; pass training=False to
+                # avoid wrongly updating them while building the bound wrapper.
+ nodesOP[n] = nodesOP[n]._replace(bound_node=op(
+ attr, inputs, nodesOP[n].output_index, self.bound_opts,
+ False))
else:
- nodesOP[n] = nodesOP[n]._replace(
- bound_node=op(
- nodesOP[n].inputs, nodesOP[n].name, None, attr,
- inputs, nodesOP[n].output_index, self.bound_opts, self.device))
+ nodesOP[n] = nodesOP[n]._replace(bound_node=op(
+ attr, inputs, nodesOP[n].output_index, self.bound_opts))
if unsupported_ops:
logger.error('Unsupported operations:')
@@ -477,51 +668,83 @@ def _convert_nodes(self, model, global_input):
logger.error(f'Name: {n.op}, Attr: {n.attr}')
raise NotImplementedError('There are unsupported operations')
+ for node in nodesIn + nodesOP:
+ node.bound_node.input_name = node.inputs
+ node.bound_node.name = node.name
+
+ nodes_dict = {}
+ for node in nodesOP + nodesIn:
+ nodes_dict[node.name] = node.bound_node
+ nodesOP = [n.bound_node for n in nodesOP]
+ nodesIn = [n.bound_node for n in nodesIn]
+ nodesOut = [nodes_dict[n] for n in nodesOut]
+
return nodesOP, nodesIn, nodesOut, template
def _build_graph(self, nodesOP, nodesIn, nodesOut, template):
- nodes = []
- for node in nodesOP + nodesIn:
- assert (node.bound_node is not None)
- nodes.append(node.bound_node)
# We were assuming that the original model had only one output node.
- # When there are multiple output nodes, this seems to be the first output element.
- # In this case, we are assuming that we aim to compute the bounds for the first
- # output element by default.
- self.final_name = nodesOut[0]
+ # When there are multiple output nodes, this seems to be the first
+ # output element. In this case, we are assuming that we aim to compute
+ # the bounds for the first output element by default.
+ self.final_name = nodesOut[0].name
self.input_name, self.input_index, self.root_name = [], [], []
- for node in nodesIn:
- self.root_name.append(node.name)
- if node.input_index is not None:
- self.input_name.append(node.name)
- self.input_index.append(node.input_index)
- self.output_name = nodesOut
+ self.output_name = [n.name for n in nodesOut]
self.output_template = template
- for l in nodes:
- self._modules[l.name] = l
- l.output_name = []
- if isinstance(l.input_name, str):
- l.input_name = [l.input_name]
- for l in nodes:
- for l_pre in l.input_name:
- self._modules[l_pre].output_name.append(l.name)
- for l in nodes:
- if self.conv_mode != 'patches' and len(l.input_name) == 0:
- if not l.name in self.root_name:
- # Add independent nodes that do not appear in `nodesIn`.
- # Note that these nodes are added in the last, since
- # order matters in the current implementation because
- # `root[0]` is used in some places.
- self.root_name.append(l.name)
+ for node in nodesIn:
+ self.add_input_node(node, index=node.input_index)
+ self.add_nodes(nodesOP)
+ if self.conv_mode == 'patches':
+ self.root_name = [node.name for node in nodesIn]
+
+ # Make sure the nodes already have `name` and `input_name`
+ def add_nodes(self, nodes):
+ nodes = [(node if isinstance(node, Bound) else node.bound_node)
+ for node in nodes]
+ for node in nodes:
+ self._modules[node.name] = node
+ node.output_name = []
+ if not hasattr(node, 'input_name'):
+ node.input_name = []
+ if isinstance(node.input_name, str):
+ node.input_name = [node.input_name]
+ if len(node.inputs) == 0:
+ self.root_name.append(node.name)
+ for node in nodes:
+ for l_pre in node.inputs:
+ l_pre.output_name.append(node.name)
+ for node in nodes:
+ if isinstance(node, BoundOptimizableActivation):
+ self.optimizable_activations.append(node)
+ if isinstance(node, BoundRelu):
+ self.relus.append(node)
+
+ def add_input_node(self, node, index=None):
+ self.add_nodes([node])
+ self.input_name.append(node.name)
+ # default value for input_index
+ if index == 'auto':
+ index = max([0] + [(i + 1)
+ for i in self.input_index if i is not None])
+ self.input_index.append(index)
+
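+    # Worked micro-example for index='auto': if self.input_index is currently
+    # [0, None, 2], then max([0] + [1, 3]) == 3 is chosen, i.e. the position
+    # right after the largest explicit input index.
+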
+ def rename_nodes(self, nodesOP, nodesIn, rename_dict):
+ def rename(node):
+ node.name = rename_dict[node.name]
+ node.input_name = [
+ rename_dict[name] for name in node.input_name]
+ return node
+ for i in range(len(nodesOP)):
+ nodesOP[i] = rename(nodesOP[i])
+ for i in range(len(nodesIn)):
+ nodesIn[i] = rename(nodesIn[i])
def _split_complex(self, nodesOP, nodesIn):
finished = True
for n in range(len(nodesOP)):
- if hasattr(nodesOP[n].bound_node, 'complex') and \
- nodesOP[n].bound_node.complex:
+ if hasattr(nodesOP[n], 'complex') and nodesOP[n].complex:
finished = False
- _nodesOP, _nodesIn, _nodesOut, _template = self._convert_nodes(
- nodesOP[n].bound_node.model, nodesOP[n].bound_node.input)
+ _nodesOP, _nodesIn, _nodesOut, _ = self._convert_nodes(
+ nodesOP[n].model, nodesOP[n].input)
# assuming each supported complex operation only has one output
assert len(_nodesOut) == 1
@@ -529,37 +752,28 @@ def _split_complex(self, nodesOP, nodesIn):
rename_dict = {}
for node in _nodesOP + _nodesIn:
rename_dict[node.name] = name_base + node.name
- num_inputs = len(nodesOP[n].bound_node.input)
+ num_inputs = len(nodesOP[n].inputs)
for i in range(num_inputs):
- rename_dict[_nodesIn[i].name] = nodesOP[n].inputs[i]
+ rename_dict[_nodesIn[i].name] = nodesOP[n].input_name[i]
rename_dict[_nodesOP[-1].name] = nodesOP[n].name
- def rename(node):
- node.bound_node.name = rename_dict[node.bound_node.name]
- node.bound_node.input_name = [
- rename_dict[name] for name in node.bound_node.input_name]
- node = node._replace(
- name=rename_dict[node.name],
- inputs=node.bound_node.input_name)
- return node
-
- for i in range(len(_nodesOP)):
- _nodesOP[i] = rename(_nodesOP[i])
- for i in range(len(_nodesIn)):
- _nodesIn[i] = rename(_nodesIn[i])
+ self.rename_nodes(_nodesOP, _nodesIn, rename_dict)
+
output_name = _nodesOP[-1].name
- # Any input node of some node within the complex node should be
- # replaced with the corresponding input node of the complex node.
+ # Any input node of some node within the complex node should be
+ # replaced with the complex node's corresponding input node.
for node in _nodesOP:
- for i in range(len(node.bound_node.inputs)):
- if node.bound_node.input_name[i] in nodesOP[n].inputs:
- index = nodesOP[n].inputs.index(node.bound_node.input_name[i])
- node.bound_node.inputs[i] = nodesOP[n].bound_node.inputs[index]
- # For any output node of this complex node, modify its input node
+ for i in range(len(node.inputs)):
+ if node.input_name[i] in nodesOP[n].input_name:
+ index = nodesOP[n].input_name.index(
+ node.input_name[i])
+ node.inputs[i] = nodesOP[n].inputs[index]
+ # For any output node of this complex node,
+ # modify its input node.
for node in nodesOP:
- if output_name in node.bound_node.input_name:
- index = node.bound_node.input_name.index(output_name)
- node.bound_node.inputs[index] = _nodesOP[-1].bound_node
+ if output_name in node.input_name:
+ index = node.input_name.index(output_name)
+ node.inputs[index] = _nodesOP[-1]
nodesOP = nodesOP[:n] + _nodesOP + nodesOP[(n + 1):]
nodesIn = nodesIn + _nodesIn[num_inputs:]
@@ -568,19 +782,21 @@ def rename(node):
return nodesOP, nodesIn, finished
- """build a dict with {ori_name: name, name: ori_name}"""
- def _get_node_name_map(self, ):
+ def _get_node_name_map(self):
+ """Build a dict with {ori_name: name, name: ori_name}"""
self.node_name_map = {}
for node in self._modules.values():
if isinstance(node, BoundInput) or isinstance(node, BoundParams):
for p in list(node.named_parameters()):
if node.ori_name not in self.node_name_map:
- self.node_name_map[node.ori_name] = node.name + '.' + p[0]
- self.node_name_map[node.name + '.' + p[0]] = node.ori_name
+ name = f'{node.name}.{p[0]}'
+ self.node_name_map[node.ori_name] = name
+ self.node_name_map[name] = node.ori_name
for p in list(node.named_buffers()):
if node.ori_name not in self.node_name_map:
- self.node_name_map[node.ori_name] = node.name + '.' + p[0]
- self.node_name_map[node.name + '.' + p[0]] = node.ori_name
+ name = f'{node.name}.{p[0]}'
+ self.node_name_map[node.ori_name] = name
+ self.node_name_map[name] = node.ori_name
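+
+    # Sketch of the resulting two-way map (names hypothetical):
+    #   {'fc1.weight': '/13.param', '/13.param': 'fc1.weight', ...}
+    # It lets load_state_dict() translate keys between the original model
+    # and the bounded module.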
# convert a Pytorch model to a model with bounds
def _convert(self, model, global_input):
@@ -591,7 +807,8 @@ def _convert(self, model, global_input):
global_input = (global_input,)
self.num_global_inputs = len(global_input)
- nodesOP, nodesIn, nodesOut, template = self._convert_nodes(model, global_input)
+ nodesOP, nodesIn, nodesOut, template = self._convert_nodes(
+ model, global_input)
global_input = self._to(global_input, self.device)
while True:
@@ -603,468 +820,314 @@ def _convert(self, model, global_input):
self._get_node_name_map()
- # load self.ori_state_dict again to avoid the running means/vars changed during forward()
+        # Load self.ori_state_dict again to avoid the running means/vars
+        # being changed during forward().
self.load_state_dict(self.ori_state_dict)
- model.load_state_dict(self.ori_state_dict)
+ if self.ori_training:
+ model.load_state_dict(self.ori_state_dict)
delattr(self, 'ori_state_dict')
# The final node used in the last time calling `compute_bounds`
self.last_final_node = None
-
- logger.debug('NodesOP:')
- for node in nodesOP:
- logger.debug('{}'.format(node._replace(param=None)))
- logger.debug('NodesIn')
- for node in nodesIn:
- logger.debug('{}'.format(node._replace(param=None)))
+ self.used_nodes = []
if self.verbose:
logger.info('Model converted to support bounds')
- def init_slope(self, x, share_slopes=False, method='backward', c=None, bound_lower=True, bound_upper=True):
- if method != 'forward':
- assert isinstance(x, tuple)
- assert method == 'backward'
- x = x[0]
+ def check_prior_bounds(self, node):
+ if node.prior_checked or not (node.used and node.perturbed):
+ return
- for node in self.optimizable_activations:
- # initialize the parameters
- node.opt_init()
+ for n in node.inputs:
+ self.check_prior_bounds(n)
- with torch.no_grad():
- if method == 'forward':
- l, u = self.compute_bounds(x=(x,), method='forward', bound_lower=bound_lower, bound_upper=bound_upper)
- else:
- l, u = self.compute_bounds(x=(x,), IBP=False, C=c, method='backward', return_A=False, bound_lower=bound_lower, bound_upper=bound_upper)
+ if getattr(node, 'nonlinear', False):
+ for n in node.inputs:
+ self.compute_intermediate_bounds(n, prior_checked=True)
+ for i in getattr(node, 'requires_input_bounds', []):
+ self.compute_intermediate_bounds(
+ node.inputs[i], prior_checked=True)
- init_intermediate_bounds = {}
- for node in self.optimizable_activations:
- if method == 'forward':
- assert not '_forward' in self._modules.keys(), '_forward is a reserved node name'
- assert isinstance(node, BoundRelu), 'Only ReLU is supported for optimizing forward bounds'
- start_nodes = [ ('_forward', 1) ]
- else:
- start_nodes = []
- for nj in self.backward_from[node.name]:
- if nj.name == self.final_name:
- size_final = self.final_shape[-1] if c is None else c.size(1)
- start_nodes.append((self.final_name, size_final))
- continue
- if share_slopes:
- # all intermediate neurons from the same layer share the same set of slopes.
- output_shape = 1
- elif isinstance(node, BoundRelu) and node.patch_size and nj.name in node.patch_size:
- # Patches mode. Use output channel size as the spec size. This still shares some alpha, but better than no sharing.
- # The patch size is [out_ch, batch, out_h, out_w, in_ch, H, W]. We use out_ch as the output shape.
- output_shape = node.patch_size[nj.name][0]
- # print(f'node {nj.name} {nj} use patch size {output_shape} patches size {node.patch_size[nj.name][0]}')
- else:
- output_shape = prod(nj.lower.shape[1:])
- # print(f'node {nj.name} {nj} use regular size {output_shape}')
- start_nodes.append((nj.name, output_shape))
- node.init_opt_parameters(start_nodes)
- node.opt_start()
- init_intermediate_bounds[node.inputs[0].name] = ([node.inputs[0].lower.detach(), node.inputs[0].upper.detach()])
-
- print("alpha-CROWN optimizable variables initialized.")
- return l, u, init_intermediate_bounds
-
- def beta_bias(self):
- batch_size = len(self.relus[-1].split_beta)
- batch = int(batch_size/2)
- bias = torch.zeros((batch_size, 1), device=self.device)
- for m in self.relus:
- if m.split_beta_used:
- bias[:batch] = bias[:batch] + m.split_bias*m.split_beta[:batch]*m.split_c[:batch]
- bias[batch:] = bias[batch:] + m.split_bias*m.split_beta[batch:]*m.split_c[batch:]
- if m.history_beta_used:
- bias = bias + (m.new_history_bias*m.new_history_beta*m.new_history_c).sum(1, keepdim=True)
- # No single node split here, because single node splits do not have bias.
- return bias
-
-
- def get_optimized_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, method='backward',
- bound_lower=True, bound_upper=False, reuse_ibp=False, return_A=False, final_node_name=None,
- average_A=False, new_interval=None, reference_bounds=None, aux_reference_bounds=None, needed_A_dict=None):
- # optimize CROWN lower bound by alpha and beta
- opts = self.bound_opts['optimize_bound_args']
- iteration = opts['ob_iteration']; beta = opts['ob_beta']; alpha = opts['ob_alpha']
- opt_coeffs = opts['ob_opt_coeffs']; opt_bias = opts['ob_opt_bias']
- verbose = opts['ob_verbose']; opt_choice = opts['ob_optimizer']
- single_node_split = opts['ob_single_node_split']
- keep_best = opts['ob_keep_best']; update_by_layer = opts['ob_update_by_layer']; init = opts['ob_init']
- lr = opts['ob_lr']; lr_beta = opts['ob_lr_beta']
- lr_intermediate_beta = opts['ob_lr_intermediate_beta']
- intermediate_beta_enabled = opts['ob_intermediate_beta']
- lr_decay = opts['ob_lr_decay']; lr_coeffs = opts['ob_lr_coeffs']
- loss_reduction_func = opts['ob_loss_reduction_func']
- stop_criterion_func = opts['ob_stop_criterion_func']
- input_grad = opts['ob_input_grad']
- sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False)
- # verbose = 1
-
- assert bound_lower != bound_upper, 'we can only optimize lower OR upper bound at one time'
- assert alpha or beta, "nothing to optimize, use compute bound instead!"
-
- if C is not None:
- self.final_shape = C.size()[:2]
- self.bound_opts.update({'final_shape': self.final_shape})
- if init:
- self.init_slope(x, share_slopes=opts['ob_alpha_share_slopes'], method=method, c=C)
-
- alphas = []
- betas = []
- parameters = []
- dense_coeffs_mask = []
-
- for m in self.optimizable_activations:
- if alpha:
- alphas.extend(list(m.alpha.values()))
-
- if alpha:
- # Alpha has shape (2, output_shape, batch_dim, node_shape)
- parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2})
- # best_alpha is a dictionary of dictionary. Each key is the alpha variable for one relu layer, and each value is a dictionary contains all relu layers after that layer as keys.
- best_alphas = OrderedDict()
- for m in self.optimizable_activations:
- best_alphas[m.name] = {}
- for alpha_m in m.alpha:
- best_alphas[m.name][alpha_m] = m.alpha[alpha_m].clone().detach()
- # We will directly replace the dictionary for each relu layer after optimization, so the saved alpha might not have require_grad=True.
- m.alpha[alpha_m].requires_grad_()
-
- if beta:
- if len(self.relus) != len(self.optimizable_activations):
- raise NotImplementedError("Beta-CROWN for tanh models is not supported yet")
-
- if single_node_split:
- for model in self.relus:
- betas.append(model.sparse_beta)
- else:
- betas = self.beta_params + self.single_beta_params
- if opt_coeffs:
- coeffs = [dense_coeffs["dense"] for dense_coeffs in self.split_dense_coeffs_params] + self.coeffs_params
- dense_coeffs_mask = [dense_coeffs["mask"] for dense_coeffs in self.split_dense_coeffs_params]
- parameters.append({'params': coeffs, 'lr': lr_coeffs})
- best_coeffs = [coeff.clone().detach() for coeff in coeffs]
- if opt_bias:
- biases = self.bias_params
- parameters.append({'params': biases, 'lr': lr_coeffs})
- best_biases = [bias.clone().detach() for bias in biases]
-
- # Beta has shape (batch, max_splits_per_layer)
- parameters.append({'params': betas, 'lr': lr_beta, 'batch_dim': 0})
- best_betas = [b.clone().detach() for b in betas]
-
- start = time.time()
-
-
- if opt_choice == "adam-autolr":
- opt = AdamElementLR(parameters, lr=lr)
- elif opt_choice == "adam":
- opt = optim.Adam(parameters, lr=lr)
- elif opt_choice == 'sgd':
- opt = optim.SGD(parameters, lr=lr, momentum=0.9)
- else:
- raise NotImplementedError(opt_choice)
+ node.prior_checked = True
+
+ def compute_intermediate_bounds(self, node, prior_checked=False):
+ if getattr(node, 'lower', None) is not None:
+ return
- # Create a weight vector to scale learning rate.
- loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device)
+ logger.debug(f'Getting the bounds of {node}')
- scheduler = optim.lr_scheduler.ExponentialLR(opt, lr_decay)
+ if not prior_checked:
+ self.check_prior_bounds(node)
- last_l = math.inf
- last_total_loss = torch.tensor(1e8, device=x[0].device, dtype=x[0].dtype)
- best_l = torch.zeros([x[0].shape[0], 1], device=x[0].device, dtype=x[0].dtype) + 1e8
+ if not node.perturbed:
+ fv = self.get_forward_value(node)
+ node.interval = node.lower, node.upper = fv, fv
+ return
+ # FIXME check that weight perturbation is not affected
+ # (from_input=True should be set for weights)
+ if not node.from_input and hasattr(node, 'forward_value'):
+ node.lower = node.upper = self.get_forward_value(node)
+ return
- best_intermediate_bounds = [] # TODO: this should be a dictionary to handle more general architectures.
+ reference_bounds = self.reference_bounds
- if sparse_intermediate_bounds and aux_reference_bounds is None and reference_bounds is not None:
- aux_reference_bounds = {}
- for name, (lb, ub) in reference_bounds.items():
- aux_reference_bounds[name] = (lb.detach().clone(), ub.detach().clone())
- if aux_reference_bounds is None:
- aux_reference_bounds = {}
+ if self.use_forward:
+ node.lower, node.upper = self.forward_general(
+ node=node, concretize=True)
+ return
- for i in range(iteration):
- intermediate_constr = None
-
- if not update_by_layer:
- reference_bounds = new_interval # If we still optimize all intermediate neurons, we can use new_interval as reference bounds.
-
- ret = self.compute_bounds(x, aux, C, method=method, IBP=IBP, forward=forward,
- bound_lower=bound_lower, bound_upper=bound_upper, reuse_ibp=reuse_ibp,
- return_A=return_A, final_node_name=final_node_name, average_A=average_A,
- # If we set neuron bounds individually, or if we are optimizing intermediate layer bounds using beta, we do not set new_interval.
- # When intermediate betas are used, we must set new_interval to None because we want to recompute all intermediate layer bounds.
- new_interval=partial_new_interval if beta and intermediate_beta_enabled else new_interval if update_by_layer else None,
- # This is the currently tightest interval, which will be used to pass split constraints when intermediate betas are used.
- reference_bounds=reference_bounds,
- # This is the interval used for checking for unstable neurons.
- aux_reference_bounds=aux_reference_bounds if sparse_intermediate_bounds else None,
- # These are intermediate layer beta variables and their corresponding A matrices and biases.
- intermediate_constr=intermediate_constr, needed_A_dict=needed_A_dict)
-
- if i == 0:
- best_ret = ret
- for model in self.optimizable_activations:
- best_intermediate_bounds.append([model.inputs[0].lower.clone().detach(), model.inputs[0].upper.clone().detach()])
- if sparse_intermediate_bounds:
- aux_reference_bounds[model.inputs[0].name] = best_intermediate_bounds[-1]
-
- ret_l, ret_u = ret[0], ret[1]
-
- if beta and opt_bias and not single_node_split:
- ret_l = ret_l + self.beta_bias()
- ret = (ret_l, ret_u) # where is A matrix?
-
- l = ret_l
- if ret_l is not None and ret_l.shape[1] != 1: # Reduction over the spec dimension.
- l = loss_reduction_func(ret_l)
- u = ret_u
- if ret_u is not None and ret_u.shape[1] != 1:
- u = loss_reduction_func(ret_u)
-
- if True:
- loss_ = l if bound_lower else -u
- stop_criterion = stop_criterion_func(ret_l) if bound_lower else stop_criterion_func(-ret_u)
- total_loss = -1 * loss_
- if type(stop_criterion) == bool:
- loss = total_loss.sum() * (not stop_criterion)
+            # FIXME: needs cleanup.
+
+            # Assign concretized bounds for ReLU layers to save computation.
+            # FIXME: Putting ReLU after a reshape will cause problems!
+ if self.check_IBP_intermediate(node):
+ # Intermediate bounds for some operators are directly
+ # computed from their input nodes by IBP
+ # (such as BoundRelu, BoundNeg)
+ logger.debug('IBP propagation for intermediate bounds on %s', node)
+ elif (isinstance(node, BoundReshape)
+ and hasattr(node.inputs[0], 'lower')
+ and hasattr(node.inputs[1], 'value')):
+ # TODO merge this with `check_IBP_intermediate`
+ # Node for input value.
+ val_input = node.inputs[0]
+ # Node for input parameter (e.g., shape, permute)
+ arg_input = node.inputs[1]
+ node.lower = node.forward(val_input.lower, arg_input.value)
+ node.upper = node.forward(val_input.upper, arg_input.value)
+ node.interval = (node.lower, node.upper)
+ else:
+ # For the first linear layer, IBP can give the same tightness
+ # as CROWN.
+ if self.check_IBP_first_linear(node):
+ return
+
+ sparse_intermediate_bounds_with_ibp = self.bound_opts.get(
+ 'sparse_intermediate_bounds_with_ibp', True)
+            # Sparse intermediate bounds can be enabled if
+            # aux_reference_bounds are given. (This is enabled for ReLU only,
+            # not for other activations.)
+ sparse_intermediate_bounds = (self.bound_opts.get(
+ 'sparse_intermediate_bounds', False)
+ and isinstance(self[node.output_name[0]], BoundRelu))
+
+ ref_intermediate_lb, ref_intermediate_ub = None, None
+ if sparse_intermediate_bounds:
+ if node.name not in self.aux_reference_bounds:
+ # If aux_reference_bounds are not available,
+ # we can use IBP to compute these bounds.
+ if sparse_intermediate_bounds_with_ibp:
+ with torch.no_grad():
+ # Get IBP bounds for this layer;
+ # we set delete_bounds_after_use=True which does
+ # not save extra intermediate bound tensors.
+ ret_ibp = self.IBP_general(
+ node=node, delete_bounds_after_use=True)
+ ref_intermediate_lb = ret_ibp[0]
+ ref_intermediate_ub = ret_ibp[1]
+ else:
+ sparse_intermediate_bounds = False
else:
- loss = (total_loss * stop_criterion.logical_not()).sum()
-
- with torch.no_grad():
- # Save varibles if this is the best iteration.
- if keep_best and (total_loss < best_l).any():
- # we only pick up the results improved in a batch
- idx = (total_loss < best_l).squeeze()
- best_l[idx] = total_loss[idx]
-
- if ret[0] is not None:
- best_ret[0][idx] = ret[0][idx]
- if ret[1] is not None:
- best_ret[1][idx] = ret[1][idx]
- if return_A:
- best_ret = (best_ret[0], best_ret[1], ret[2])
-
- for ii, model in enumerate(self.optimizable_activations):
- # best_intermediate_bounds.append([model.inputs[0].lower, model.inputs[0].upper])
- best_intermediate_bounds[ii][0][idx] = torch.max(best_intermediate_bounds[ii][0][idx], model.inputs[0].lower[idx])
- best_intermediate_bounds[ii][1][idx] = torch.min(best_intermediate_bounds[ii][1][idx], model.inputs[0].upper[idx])
- if alpha:
- # each alpha has shape (2, output_shape, batch, *shape)
- for alpha_m in model.alpha:
- best_alphas[model.name][alpha_m][:,:,idx] = model.alpha[alpha_m][:,:,idx].clone().detach()
- if beta and single_node_split:
- best_betas[ii][idx] = betas[ii][idx].clone().detach()
-
- if not single_node_split and beta:
- for ii, b in enumerate(betas):
- best_betas[ii][idx] = b[idx].clone().detach()
-
- if opt_coeffs:
- best_coeffs = [co.clone().detach() for co in coeffs] # TODO: idx-wise
- if opt_bias:
- best_biases = [bias.clone().detach() for bias in biases] # TODO: idx-wise
-
-
- if os.environ.get('AUTOLIRPA_DEBUG_OPT', False):
- print(f"****** iter [{i}]",
- f"loss: {loss.item()}, lr: {opt.param_groups[0]['lr']}")
-
- if isinstance(stop_criterion, torch.Tensor) and stop_criterion.all():
- print(f"\nall verified at {i}th iter")
- break
+ aux_bounds = self.aux_reference_bounds[node.name]
+ ref_intermediate_lb, ref_intermediate_ub = aux_bounds
+
+ sparse_C = self.get_sparse_C(
+ node, sparse_intermediate_bounds,
+ ref_intermediate_lb, ref_intermediate_ub)
+ newC, reduced_dim, unstable_idx, unstable_size = sparse_C
+
+ if unstable_idx is None or unstable_size > 0:
+ if self.return_A:
+ node.lower, node.upper, _ = self.backward_general(
+ C=newC, node=node, unstable_idx=unstable_idx,
+ unstable_size=unstable_size)
+ else:
+ # Compute backward bounds only when there are unstable
+ # neurons, or when we don't know which neurons are unstable.
+ node.lower, node.upper = self.backward_general(
+ C=newC, node=node, unstable_idx=unstable_idx,
+ unstable_size=unstable_size)
+
+ if reduced_dim:
+ self.restore_sparse_bounds(
+ node, unstable_idx, unstable_size,
+ ref_intermediate_lb, ref_intermediate_ub)
+
+ # node.lower and node.upper (intermediate bounds) are computed in
+ # the above function. If we have bound references, we set them here
+ # to always obtain a better set of bounds.
+ if node.name in reference_bounds:
+ ref_bounds = reference_bounds[node.name]
+                # Initially, the reference bound and the computed bound can be
+                # exactly the same when the intermediate layer beta is 0. This
+                # prevents gradients from flowing, so we need a small guard here.
+ if self.intermediate_constr is not None:
+ # Intermediate layer beta is used.
+ # Note that we cannot just take the reference bounds if
+ # they are better - this makes alphas have zero gradients.
+ new_lower = 0.9 * ref_bounds[0] + 0.1 * node.lower
+ new_upper = 0.9 * ref_bounds[1] + 0.1 * node.upper
+ node.lower = torch.max(new_lower, node.lower)
+ node.upper = torch.min(new_upper, node.upper)
+ # Additionally, if the reference bounds say a neuron is
+ # stable, we always keep it. (FIXME: this is for ReLU only).
+ lower_stable = ref_bounds[0] >= 0.
+ node.lower[lower_stable] = ref_bounds[0][lower_stable]
+ upper_stable = ref_bounds[1] <= 0.
+ node.upper[upper_stable] = ref_bounds[1][upper_stable]
+ else:
+ # Set the intermediate layer bounds using reference bounds,
+ # always choosing the tighter one.
+ node.lower = (
+ torch.max(ref_bounds[0], node.lower).detach()
+ - node.lower.detach() + node.lower)
+ node.upper = (
+ node.upper - (node.upper.detach()
+ - torch.min(ref_bounds[1], node.upper).detach()))
+ # Otherwise, we only use reference bounds to check which neurons
+ # are unstable.
+
+ # FIXME (12/28): we should be consistent, and only use
+ # node.interval, do not use node.lower or node.upper!
+ node.interval = (node.lower, node.upper)
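+
+            # A minimal sketch of the straight-through pattern used above:
+            # the forward value equals the tighter bound while gradients
+            # still flow through the original bound, e.g. (illustrative):
+            #   lower = torch.tensor([1., -2.], requires_grad=True)
+            #   ref_lb = torch.tensor([0.5, 0.])
+            #   out = torch.max(ref_lb, lower).detach() - lower.detach() + lower
+            #   # out == [1., 0.], yet d(out)/d(lower) == 1 elementwise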
- current_lr = []
- for param_group in opt.param_groups:
- current_lr.append(param_group['lr'])
-
- opt.zero_grad(set_to_none=True)
-
- if input_grad and x[0].ptb.x_L.grad is not None:
- x[0].ptb.x_L.grad = None
- x[0].ptb.x_U.grad = None
-
- loss.backward()
-
- if verbose > 0:
- print(f"*** iter [{i}]\n", f"loss: {loss.item()}", total_loss.squeeze().detach().cpu().numpy(), "lr: ", current_lr)
- if beta:
- masked_betas = []
- for model in self.relus:
- masked_betas.append(model.masked_beta)
- if model.history_beta_used:
- print(f"{model.name} history beta", model.new_history_beta.squeeze())
- if model.split_beta_used:
- print(f"{model.name} split beta:", model.split_beta.view(-1))
- print(f"{model.name} bias:", model.split_bias)
- if opt_coeffs:
- for co in coeffs:
- print(f'coeff sum: {co.abs().sum():.5g}')
- if beta and i == 0 and verbose > 0:
- breakpoint()
-
- if opt_choice == "adam-autolr":
- opt.step(lr_scale=[loss_weight, loss_weight])
- else:
- opt.step()
-
- if beta:
- # Clipping to >=0.
- for b in betas:
- b.data = (b >= 0) * b.data
- for dmi in range(len(dense_coeffs_mask)):
- # apply dense mask to the dense split coeffs matrix
- coeffs[dmi].data = dense_coeffs_mask[dmi].float() * coeffs[dmi].data
-
- if alpha:
- for m in self.relus:
- for m_start_node, v in m.alpha.items():
- v.data = torch.clamp(v.data, 0., 1.)
- # print(f'layer {m.name} start_node {m_start_node} shape {v.size()} norm {v[:,:,0].abs().sum()} {v[:,:,-1].abs().sum()} {v.abs().sum()}')
- # For tanh, we clip it in bound_ops because clipping depends. TODO: clipping should be a method in the BoundOptimizableActivation class.
- # on pre-activation bounds
-
- # If loss has become worse for some element, reset those to current best.
- with torch.no_grad():
- if beta and opt_choice == "adam-autolr" and i > iteration * 0.2:
- for ii, model in enumerate(self.relus):
- if alpha:
- # each alpha has shape (2, output_shape, batch, *shape)
- for alpha_m in model.alpha:
- model.alpha[alpha_m][:,:,worse_idx] = best_alphas[model.name][alpha_m][:,:,worse_idx].clone().detach()
- if beta and single_node_split:
- betas[ii][worse_idx] = best_betas[ii][worse_idx].clone().detach()
-
- scheduler.step()
- last_l = loss.item()
- last_total_loss = total_loss.detach().clone()
-
- # if beta and intermediate_beta_enabled and verbose > 0:
- if verbose > 0:
- breakpoint()
-
- if keep_best:
- # Set all variables to their saved best values.
- with torch.no_grad():
- for idx, model in enumerate(self.optimizable_activations):
- if alpha:
- # Assigns a new dictionary.
- model.alpha = best_alphas[model.name]
- model.inputs[0].lower.data = best_intermediate_bounds[idx][0].data
- model.inputs[0].upper.data = best_intermediate_bounds[idx][1].data
- if beta:
- if single_node_split:
- model.sparse_beta.copy_(best_betas[idx])
- else:
- for b, bb in zip(betas, best_betas):
- b.data = bb.data
- if opt_coeffs:
- for co, bco in zip(coeffs, best_coeffs):
- co.data = bco.data
- if opt_bias:
- for bias, bbias in zip(biases, best_biases):
- bias.data = bbias.data
-
- if new_interval is not None and not update_by_layer:
- for l in self._modules.values():
- if l.name in new_interval.keys() and hasattr(l, "lower"):
- # l.interval = tuple(new_interval[l.name][:2])
- l.lower = torch.max(l.lower, new_interval[l.name][0])
- l.upper = torch.min(l.upper, new_interval[l.name][1])
- infeasible_neurons = l.lower > l.upper
- if infeasible_neurons.any():
- print('infeasible!!!!!!!!!!!!!!', infeasible_neurons.sum().item(), infeasible_neurons.nonzero()[:, 0])
-
- print("best_l after optimization:", best_l.sum().item(), "with beta sum per layer:", [p.sum().item() for p in betas])
- # np.save('solve_slope.npy', np.array(record))
- print('optimal alpha/beta time:', time.time() - start)
- return best_ret
-
-
- def get_unstable_conv_locations(self, node, aux_reference_bounds):
- # For conv layer we only check the case where all neurons are active/inactive.
- unstable_masks = torch.logical_and(aux_reference_bounds[node.name][0] < 0, aux_reference_bounds[node.name][1] > 0)
- # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable mask.
- # It has shape (H, W) indicating if a neuron is unstable/stable.
- # TODO: so far we merge over the batch dimension to allow easier implementation.
- unstable_locs = unstable_masks.sum(dim=0).bool()
- # Now converting it to indices for these unstable nuerons. These are locations (i,j) of unstable neurons.
- unstable_idx = unstable_locs.nonzero(as_tuple=True)
- # Number of unstable neurons.
- unstable_size = unstable_idx[0].numel()
- # print(f'layer {node.name} unstable_size {unstable_size} actual {unstable_masks.sum().item()} size {node.output_shape}')
- # We sum over the channel direction, so need to multiply that.
- return unstable_idx, unstable_size
-
-
- def get_unstable_locations(self, node, aux_reference_bounds):
- # FIXME (09/19): this is for ReLU only!
- unstable_masks = torch.logical_and(aux_reference_bounds[node.name][0] < 0, aux_reference_bounds[node.name][1] > 0)
- # unstable_masks = torch.ones(dtype=torch.bool, size=(batch_size, dim), device=self.device)
- if unstable_masks.ndim > 2:
- # Flatten the conv layer shape.
- unstable_masks = unstable_masks.view(unstable_masks.size(0), -1)
- # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable mask.
- unstable_locs = unstable_masks.sum(dim=0).bool()
- # This is a 1-d indices, shared by all elements in this batch.
- unstable_idx = unstable_locs.nonzero().squeeze(1)
- unstable_size = unstable_idx.numel()
- # print(f'layer {node.name} unstable {unstable_size} total {node.output_shape}')
- return unstable_idx, unstable_size
-
- def compute_bounds(self, x=None, aux=None, C=None, method='backward', IBP=False, forward=False,
- bound_lower=True, bound_upper=True, reuse_ibp=False,
- return_A=False, needed_A_dict=None, final_node_name=None, average_A=False, new_interval=None,
- return_b=False, b_dict=None, reference_bounds=None, intermediate_constr=None, alpha_idx=None,
- aux_reference_bounds=None, need_A_only=False):
+ def merge_A_dict(self, lA_dict, uA_dict):
+ merged_A = {}
+ for output_node_name in lA_dict:
+ merged_A[output_node_name] = {}
+ lA_dict_ = lA_dict[output_node_name]
+ uA_dict_ = uA_dict[output_node_name]
+ for input_node_name in lA_dict_:
+ merged_A[output_node_name][input_node_name] = {
+ 'lA': lA_dict_[input_node_name]['lA'],
+ 'uA': uA_dict_[input_node_name]['uA'],
+ 'lbias': lA_dict_[input_node_name]['lbias'],
+ 'ubias': uA_dict_[input_node_name]['ubias'],
+ }
+ return merged_A
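+    # Illustrative sketch (node names '/out' and '/input' are hypothetical):
+    # the merged dictionary takes lower-bound entries from `lA_dict` and
+    # upper-bound entries from `uA_dict`, e.g.
+    #   merged_A['/out']['/input'] == {'lA': ..., 'uA': ...,
+    #                                  'lbias': ..., 'ubias': ...}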
+
+ def compute_bounds(
+ self, x=None, aux=None, C=None, method='backward', IBP=False,
+ forward=False, bound_lower=True, bound_upper=True, reuse_ibp=False,
+ reuse_alpha=False, return_A=False, needed_A_dict=None,
+ final_node_name=None, average_A=False,
+ intermediate_layer_bounds=None, reference_bounds=None,
+ intermediate_constr=None, alpha_idx=None,
+ aux_reference_bounds=None, need_A_only=False,
+ cutter=None, decision_thresh=None,
+ update_mask=None):
r"""Main function for computing bounds.
Args:
- x (tuple or None): Input to the model. If it is None, the input from the last
- `forward` or `compute_bounds` call is reused. Otherwise: the number of elements in the tuple should be
- equal to the number of input nodes in the model, and each element in the tuple
- corresponds to the value for each input node respectively. It should look similar
- as the `global_input` argument when used for creating a `BoundedModule`.
-
- aux (object, optional): Auxliary information that can be passed to `Perturbation`
- classes for initializing and concretizing bounds, e.g., additional information
- for supporting synonym word subsitution perturbaiton.
-
- C (Tensor): The specification matrix that can map the output of the model with an
- additional linear layer. This is usually used for maping the logits output of the
- model to classification margins.
-
- method (str): The main method for bound computation. Choices:
+ x (tuple or None): Input to the model. If it is None, the input
+ from the last `forward` or `compute_bounds` call is reused.
+            Otherwise, the number of elements in the tuple should be
+            equal to the number of input nodes in the model, and each element
+            in the tuple corresponds to the value of the respective input
+            node. It should look similar to the `global_input` argument used
+            when creating a `BoundedModule`.
+
+            aux (object, optional): Auxiliary information that can be passed
+            to `Perturbation` classes for initializing and concretizing
+            bounds, e.g., additional information for supporting synonym word
+            substitution perturbation.
+
+            C (Tensor): The specification matrix that can map the output of
+            the model with an additional linear layer. This is usually used
+            for mapping the logits output of the model to classification
+            margins.
+
+ method (str): The main method for bound computation. Choices:
* `IBP`: purely use Interval Bound Propagation (IBP) bounds.
- * `CROWN-IBP`: use IBP to compute intermediate bounds, but use CROWN (backward mode LiRPA) to compute the bounds of the final node.
- * `CROWN`: purely use CROWN to compute bounds for intermediate nodes and the final node.
- * `Forward`: purely use forward mode LiRPA to compute the bounds.
- * `Forward+Backward`: use forward mode LiRPA to compute bounds for intermediate nodes, but further use CROWN to compute bounds for the final node.
- * `CROWN-Optimized` or `alpha-CROWN`: use CROWN, and also optimize the linear relaxation parameters for activations.
-
- IBP (bool, optional): If `True`, use IBP to compute the bounds of intermediate nodes.
- It can be automatically set according to `method`.
-
- forward (bool, optional): If `True`, use the forward mode bound propagation to compute the bounds
- of intermediate nodes. It can be automatically set according to `method`.
-
- bound_lower (bool, default `True`): If `True`, the lower bounds of the output needs to be computed.
-
- bound_upper (bool, default `True`): If `True`, the upper bounds of the output needs to be computed.
-
- reuse_ibp (bool, optional): If `True` and `method` is None, reuse the previously saved IBP bounds.
+ * `CROWN-IBP`: use IBP to compute intermediate bounds,
+ but use CROWN (backward mode LiRPA) to compute the bounds of the
+ final node.
+ * `CROWN`: purely use CROWN to compute bounds for intermediate
+ nodes and the final node.
+ * `Forward`: purely use forward mode LiRPA.
+ * `Forward+Backward`: use forward mode LiRPA for intermediate
+ nodes, but further use CROWN for the final node.
+ * `CROWN-Optimized` or `alpha-CROWN`: use CROWN, and also
+ optimize the linear relaxation parameters for activations.
+ * `forward-optimized`: use forward bounds with optimized linear
+ relaxation.
+
+ IBP (bool, optional): If `True`, use IBP to compute the bounds of
+ intermediate nodes. It can be automatically set according to
+ `method`.
+
+ forward (bool, optional): If `True`, use the forward mode bound
+ propagation to compute the bounds of intermediate nodes. It can be
+ automatically set according to `method`.
+
+            bound_lower (bool, default `True`): If `True`, the lower bounds
+            of the output need to be computed.
+
+            bound_upper (bool, default `True`): If `True`, the upper bounds
+            of the output need to be computed.
+
+ reuse_ibp (bool, optional): If `True` and `method` is None, reuse
+ the previously saved IBP bounds.
+
+ final_node_name (str, optional): Set the final node in the
+ computational graph for bound computation. By default, the final
+ node of the originally built computational graph is used.
+
+            return_A (bool, optional): If `True`, return the linear
+            coefficients (`A` tensors) computed in bound propagation for the
+            nodes specified in `needed_A_dict`.
+
+            needed_A_dict (dict, optional): A dictionary specifying which
+            linear coefficients (`A` tensors) are needed and should be
+            returned. Each key is the name of a starting node in backward
+            bound propagation, and its value is a list of names of ending
+            nodes; the linear coefficients of the starting node w.r.t. the
+            specified ending nodes are returned. By default, it is empty.
+
+ reuse_alpha (bool, optional): If `True`, reuse previously saved
+ alpha values when they are not being optimized.
+
+            decision_thresh (float, optional): In CROWN-optimized mode, this
+            threshold is used to dynamically optimize only those domains
+            whose current bounds are <= the threshold.
+
+            intermediate_layer_bounds: A dictionary of 2-element tuples/lists
+            containing lower and upper bounds for intermediate layers.
+            The dictionary keys should include the names of the layers whose
+            bounds should be set without recomputation. The layer names can
+            be viewed by setting the environment variable
+            AUTOLIRPA_DEBUG_GRAPH=1. The value of each dictionary element is
+            a (lower_bounds, upper_bounds) pair, where "lower_bounds" and
+            "upper_bounds" are two tensors with the same shape as the output
+            shape of this layer. If you only need to set intermediate layer
+            bounds for certain layers, just include these layers' names in
+            the dictionary.
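+            For example, `{'/relu0': (lb, ub)}` (a sketch; the layer name is
+            hypothetical) fixes the bounds of layer '/relu0' to the tensors
+            `lb` and `ub`, which must match that layer's output shape.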
+
+ reference_bounds: Format is similar to "intermediate_layer_bounds".
+ However, these bounds are only used as a reference, and the bounds
+ for intermediate layers will still be computed (e.g., using CROWN,
+ IBP or other specified methods). The computed bounds will be
+ compared to "reference_bounds" and the tighter one between the two
+ will be used.
+
+            aux_reference_bounds: Format is similar to
+            "intermediate_layer_bounds". However, these bounds are only used
+            to determine which neurons are stable and which are unstable for
+            ReLU networks. Unstable neurons' intermediate layer bounds will
+            be recomputed.
Returns:
- bound (tuple): a tuple of computed lower bound and upper bound respectively.
+ bound (tuple): When `return_A` is `False`, return a tuple of
+ the computed lower bound and upper bound. When `return_A`
+ is `True`, return a tuple of lower bound, upper bound, and
+ `A` dictionary.
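+
+        Example:
+            A minimal sketch, assuming `model` is a `BoundedModule` and `x`
+            is a `BoundedTensor` with a perturbation attached (the node
+            names in `needed_A_dict` below are hypothetical)::
+
+                lb, ub = model.compute_bounds(x=(x,), method='CROWN')
+                lb, ub, A_dict = model.compute_bounds(
+                    x=(x,), method='CROWN', return_A=True,
+                    needed_A_dict={'/out': ['/input']})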
"""
+ logger.debug(f'Compute bounds with {method}')
+ if needed_A_dict is None: needed_A_dict = {}
if not bound_lower and not bound_upper:
- raise ValueError('At least one of bound_lower and bound_upper must be True')
+ raise ValueError(
+ 'At least one of bound_lower and bound_upper must be True')
# Several shortcuts.
method = method.lower() if method is not None else method
@@ -1079,951 +1142,233 @@ def compute_bounds(self, x=None, aux=None, C=None, method='backward', IBP=False,
method = 'backward'
elif method == 'forward':
forward = True
- elif method == 'forward+backward':
+ elif method == 'forward+backward' or method == 'forward+crown':
method = 'backward'
forward = True
- elif method in ['crown-optimized', 'alpha-crown']:
- assert return_A is False
- ret = []
+ elif method in ['crown-optimized', 'alpha-crown', 'forward-optimized']:
+ # Lower and upper bounds need two separate rounds of optimization.
+ if method == 'forward-optimized':
+ method = 'forward'
+ else:
+ method = 'backward'
if bound_lower:
- ret1 = self.get_optimized_bounds(x=x, IBP=False, C=C, method='backward', new_interval=new_interval, reference_bounds=reference_bounds,
- bound_lower=bound_lower, bound_upper=False, return_A=return_A, aux_reference_bounds=aux_reference_bounds,
- needed_A_dict=needed_A_dict)
+ ret1 = self.get_optimized_bounds(
+ x=x, C=C, method=method,
+ intermediate_layer_bounds=intermediate_layer_bounds,
+ reference_bounds=reference_bounds, bound_lower=bound_lower,
+ bound_upper=False, return_A=return_A,
+ aux_reference_bounds=aux_reference_bounds,
+ needed_A_dict=needed_A_dict,
+ final_node_name=final_node_name,
+ cutter=cutter, decision_thresh=decision_thresh)
if bound_upper:
- ret2 = self.get_optimized_bounds(x=x, IBP=False, C=C, method='backward', new_interval=new_interval, reference_bounds=reference_bounds,
- bound_lower=False, bound_upper=bound_upper, return_A=return_A, aux_reference_bounds=aux_reference_bounds,
- needed_A_dict=needed_A_dict)
+ ret2 = self.get_optimized_bounds(
+ x=x, C=C, method=method,
+ intermediate_layer_bounds=intermediate_layer_bounds,
+ reference_bounds=reference_bounds, bound_lower=False,
+ bound_upper=bound_upper, return_A=return_A,
+ aux_reference_bounds=aux_reference_bounds,
+ needed_A_dict=needed_A_dict,
+ final_node_name=final_node_name,
+ cutter=cutter, decision_thresh=decision_thresh)
if bound_lower and bound_upper:
- return ret1[0], ret2[1]
+ if return_A:
+ # Needs to merge the A dictionary.
+ lA_dict = ret1[2]
+ uA_dict = ret2[2]
+ merged_A = self.merge_A_dict(lA_dict, uA_dict)
+ return ret1[0], ret2[1], merged_A
+ else:
+ return ret1[0], ret2[1]
elif bound_lower:
- return ret1
+ return ret1 # ret1[1] is None.
elif bound_upper:
- return ret2
+ return ret2 # ret2[0] is None.
if reference_bounds is None:
reference_bounds = {}
if aux_reference_bounds is None:
aux_reference_bounds = {}
- # If y in self.backward_node_pairs[x], then node y is visited when
+ # If y in self.backward_node_pairs[x], then node y is visited when
# doing backward bound propagation starting from node x.
self.backward_from = dict([(node, []) for node in self._modules])
if not bound_lower and not bound_upper:
- raise ValueError('At least one of bound_lower and bound_upper in compute_bounds should be True')
+ raise ValueError(
+ 'At least one of bound_lower and bound_upper in compute_bounds '
+ 'should be True')
A_dict = {} if return_A else None
if x is not None:
- self._set_input(*x, new_interval=new_interval)
+ self._set_input(
+ *x, intermediate_layer_bounds=intermediate_layer_bounds)
if IBP and method is None and reuse_ibp:
# directly return the previously saved ibp bounds
return self.ibp_lower, self.ibp_upper
- root = [self._modules[name] for name in self.root_name]
+ root = [self[name] for name in self.root_name]
batch_size = root[0].value.shape[0]
dim_in = 0
for i in range(len(root)):
value = root[i].forward()
- if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
- root[i].linear, root[i].center, root[i].aux = \
- root[i].perturbation.init(value, aux=aux, forward=forward)
- # This input/parameter has perturbation. Create an interval object.
- if self.ibp_relative:
- root[i].interval = Interval(
- None, None, root[i].linear.nominal, root[i].linear.lower_offset, root[i].linear.upper_offset)
- else:
- root[i].interval = Interval(
- root[i].linear.lower, root[i].linear.upper, ptb=root[i].perturbation)
+ if getattr(root[i], 'perturbation', None) is not None:
+ ret_init = root[i].perturbation.init(
+ value, aux=aux, forward=forward)
+ root[i].linear, root[i].center, root[i].aux = ret_init
+ # This input/parameter has perturbation.
+ # Create an interval object.
+ root[i].interval = Interval(
+ root[i].linear.lower, root[i].linear.upper,
+ ptb=root[i].perturbation)
if forward:
root[i].dim = root[i].linear.lw.shape[1]
dim_in += root[i].dim
else:
- if self.ibp_relative:
- root[i].interval = Interval(
- None, None, value, torch.zeros_like(value), torch.zeros_like(value))
- else:
- # This inpute/parameter does not has perturbation.
- # Use plain tuple defaulting to Linf perturbation.
- root[i].interval = (value, value)
- root[i].forward_value = root[i].forward_value = root[i].value = root[i].lower = root[i].upper = value
+                # This input/parameter does not have perturbation.
+                # Use plain tuple defaulting to Linf perturbation.
+ root[i].interval = (value, value)
+ root[i].forward_value = root[i].value = value
+ root[i].lower = root[i].upper = value
- if self.ibp_relative:
- root[i].lower, root[i].upper = root[i].interval.lower, root[i].interval.upper
- else:
- root[i].lower, root[i].upper = root[i].interval
+ root[i].lower, root[i].upper = root[i].interval
if forward:
- self._init_forward(root, dim_in)
+ self.init_forward(root, dim_in)
- final = self._modules[self.final_name] if final_node_name is None else self._modules[final_node_name]
- logger.debug('Final node {}[{}]'.format(final, final.name))
+        final = (self.final_node() if final_node_name is None
+                 else self[final_node_name])
+ logger.debug(f'Final node {final.__class__.__name__}({final.name})')
if IBP:
- res = self._IBP_general(node=final, C=C)
- if self.ibp_relative:
- self.ibp_lower, self.ibp_upper = res.lower, res.upper
- else:
- self.ibp_lower, self.ibp_upper = res
+ self.ibp_lower, self.ibp_upper = self.IBP_general(node=final, C=C)
if method is None:
- return self.ibp_lower, self.ibp_upper
+ return self.ibp_lower, self.ibp_upper
if C is None:
- # C is an identity matrix by default
+ # C is an identity matrix by default
if final.output_shape is None:
- raise ValueError('C is not provided while node {} has no default shape'.format(final.shape))
+ raise ValueError(
+                f'C is not provided while node {final} has no default shape')
dim_output = int(prod(final.output_shape[1:]))
# TODO: use an eyeC object here.
- C = torch.eye(dim_output, device=self.device).expand(batch_size, dim_output, dim_output)
+ C = torch.eye(dim_output, device=self.device).expand(
+ batch_size, dim_output, dim_output)
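+            # For example (a sketch): for a classifier with true class y, a
+            # specification row with +1 at position y and -1 at position j
+            # yields bounds on the margin logit_y - logit_j rather than on
+            # individual logits.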
+
+ # Reuse previously saved alpha values,
+ # even if they are not optimized now
+ if reuse_alpha:
+ for node in self.optimizable_activations:
+ node.opt_reuse()
+ else:
+ for node in self.optimizable_activations:
+ node.opt_no_reuse()
+
+ # Inject update mask inside the activations
+ # update_mask: None or bool tensor([batch_size])
+        # If set to a tensor, only update the alpha and beta of the selected
+        # elements (those with mask value 1).
+
+ if update_mask is None:
+ for node in self.optimizable_activations:
+ node.clean_alpha_beta_update_mask()
+ else:
+ for node in self.optimizable_activations:
+ node.set_alpha_beta_update_mask(update_mask)
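+        # For example (a sketch), update_mask=torch.tensor([True, False,
+        # True]) updates alpha/beta for the first and third batch elements
+        # while keeping the second frozen.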
- # check whether weights are perturbed and set nonlinear for the BoundMatMul operation
for n in self._modules.values():
- if type(n) in [BoundLinear, BoundConv, BoundBatchNormalization]:
+ # Check whether all prior intermediate bounds already exist
+ n.prior_checked = False
+ # check whether weights are perturbed and set nonlinear for the
+ # BoundMatMul operation
+ if isinstance(n, (BoundLinear, BoundConv, BoundBatchNormalization)):
n.nonlinear = False
- for l_name in n.input_name[1:]:
- node = self._modules[l_name]
+ for node in n.inputs[1:]:
if hasattr(node, 'perturbation'):
if node.perturbation is not None:
n.nonlinear = True
+            if isinstance(n, BoundRelu):
+                for node in n.inputs:
+                    if isinstance(node, BoundConv):
+                        # Whether this Conv is followed by a ReLU.
+                        node.relu_followed = True
# BFS to find out whether each node is used given the current final node
- if final != self.last_final_node:
- self.last_final_node = final
+ self._set_used_nodes(final)
+
+ # FIXME clean
+ self.use_forward = forward
+ self.root = root
+ self.batch_size = batch_size
+ self.dim_in = dim_in
+ self.return_A = return_A
+ self.A_dict = A_dict
+ self.needed_A_dict = needed_A_dict
+ self.intermediate_constr = intermediate_constr
+ self.reference_bounds = reference_bounds
+ self.aux_reference_bounds = aux_reference_bounds
+ self.final_node_name = final.name
+
+ self.check_prior_bounds(final)
+
+ if method == 'backward':
+ # This is for the final output bound.
+ # No need to pass in intermediate layer beta constraints.
+ ret = self.backward_general(
+ C=C, node=final,
+ bound_lower=bound_lower, bound_upper=bound_upper,
+ average_A=average_A, need_A_only=need_A_only,
+ unstable_idx=alpha_idx, update_mask=update_mask)
+ # FIXME when C is specified, lower and upper should not be saved to
+ # final.lower and final.upper, because they are not the bounds for
+ # the node.
+ final.lower, final.upper = ret[0], ret[1]
+ return ret
+ elif method == 'forward':
+ return self.forward_general(C=C, node=final, concretize=True)
+ else:
+ raise NotImplementedError
+
+ def _set_used_nodes(self, final):
+ if final.name != self.last_final_node:
+ self.last_final_node = final.name
+ self.used_nodes = []
for i in self._modules.values():
i.used = False
final.used = True
queue = deque([final])
while len(queue) > 0:
n = queue.popleft()
- for n_pre_name in n.input_name:
- n_pre = self._modules[n_pre_name]
+ self.used_nodes.append(n)
+ for n_pre in n.inputs:
if not n_pre.used:
n_pre.used = True
queue.append(n_pre)
- for i in self._modules.values():
- if isinstance(i, BoundRelu):
- for l_name in i.input_name:
- node = self._modules[l_name]
- if isinstance(node, BoundConv):
- node.relu_followed = True # whether this Conv is followed by a ReLU
-
- for i in self._modules.values(): # for all nodes
- if not i.used:
- continue
- if hasattr(i, 'nonlinear') and i.nonlinear:
- for l_name in i.input_name:
- node = self._modules[l_name]
- if not hasattr(node, 'lower'):
- assert not IBP, 'There should be no missing intermediate bounds when IBP is enabled'
- if not node.perturbed and hasattr(node, 'forward_value'):
- node.interval = node.lower, node.upper = \
- node.forward_value, node.forward_value
- continue
- # FIXME check that weight perturbation is not affected
- # (from_input=True should be set for weights)
- if not node.from_input and hasattr(node, 'forward_value'):
- node.lower = node.upper = node.forward_value
- continue
- if forward:
- l, u = self._forward_general(
- node=node, root=root, dim_in=dim_in, concretize=True)
- else:
- # assign concretized bound for ReLU layer to save computational cost
- # FIXME: Put ReLU after reshape will cause problem!
- if (isinstance(node, BoundActivation) or isinstance(node, BoundTranspose)) and hasattr(
- self._modules[node.input_name[0]], 'lower'):
- node.lower = node.forward(self._modules[node.input_name[0]].lower)
- node.upper = node.forward(self._modules[node.input_name[0]].upper)
- elif isinstance(node, BoundReshape) and \
- hasattr(self._modules[node.input_name[0]], 'lower') and \
- hasattr(self._modules[node.input_name[1]], 'value'):
- # Node for input value.
- val_input = self._modules[node.input_name[0]]
- # Node for input parameter (e.g., shape, permute)
- arg_input = self._modules[node.input_name[1]]
- node.lower = node.forward(val_input.lower, arg_input.value)
- node.upper = node.forward(val_input.upper, arg_input.value)
- else:
- first_layer_flag = False
- # This is the list of all intermediate layers where we need to refine.
- if intermediate_constr is not None:
- intermediate_beta_enabled_layers = [k for v in intermediate_constr.values() for k in v]
- else:
- intermediate_beta_enabled_layers = []
- # Here we avoid creating a big C matrix in the first linear layer.
- # Disable this optimization when we have beta for intermediate layer bounds.
- if type(node) == BoundLinear or type(node) == BoundConv and node.name not in intermediate_beta_enabled_layers:
- for l_pre in node.input_name:
- if type(self._modules[l_pre]) == BoundInput:
- node.lower, node.upper = self._IBP_general(node)
- first_layer_flag = True
- break
- if not first_layer_flag:
- reduced_dim = False # Only partial neurons (unstable neurons) are bounded.
- unstable_idx = None
- unstable_size = 99999
- dim = int(prod(node.output_shape[1:]))
- sparse_intermediate_bounds = node.name in aux_reference_bounds and self.bound_opts.get('sparse_intermediate_bounds', False) and isinstance(self._modules[node.output_name[0]], BoundRelu)
- sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False)
- # FIXME: C matrix shape incorrect for BoundParams.
- if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int(
- os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0:
- if sparse_intermediate_bounds:
- # If we are doing bound refinement and reference bounds are given, we only refine unstable neurons.
- # Also, if we are checking against LP solver we will refine all neurons and do not use this optimization.
- # For each batch element, we find the unstable neurons.
- unstable_idx, unstable_size = self.get_unstable_locations(node, aux_reference_bounds)
- if unstable_size == 0:
- # Do nothing, no bounds will be computed.
- reduced_dim = True
- unstable_idx = []
- elif unstable_size < 0.9 * dim and unstable_size > 0:
- # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches.
- newC = OneHotC([batch_size, unstable_size, *node.output_shape[1:]], self.device, unstable_idx, None)
- reduced_dim = True
- else:
- unstable_idx = None
- if not reduced_dim:
- newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device)
- elif (isinstance(node, BoundConv) or isinstance(node,
- BoundBatchNormalization)) and node.mode == "patches":
- if sparse_intermediate_bounds:
- unstable_idx, unstable_size = self.get_unstable_conv_locations(node, aux_reference_bounds)
- if unstable_size == 0:
- # Do nothing, no bounds will be computed.
- reduced_dim = True
- unstable_idx = []
- # We sum over the channel direction, so need to multiply that.
- elif sparse_conv_intermediate_bounds and unstable_size < 0.8 * dim:
- # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches.
- # The shape of patches is [unstable_size, batch, C, H, W].
- newC = Patches(patches=None, stride=1, padding=0, shape=[
- unstable_size, batch_size, node.output_shape[-3], 1, 1],
- identity=1, unstable_idx=unstable_idx, output_shape=node.output_shape)
- reduced_dim = True
- else:
- unstable_idx = None
- # Here we create an Identity Patches object
- if not reduced_dim:
- newC = Patches(None, 1, 0,
- [node.output_shape[-3], batch_size, node.output_shape[-2], node.output_shape[-1],
- node.output_shape[-3], 1, 1], 1, output_shape=node.output_shape)
- elif isinstance(node, BoundAdd) and node.mode == "patches":
- if sparse_intermediate_bounds:
- unstable_idx, unstable_size = self.get_unstable_conv_locations(node, aux_reference_bounds)
- if unstable_size == 0:
- # Do nothing, no bounds will be computed.
- reduced_dim = True
- unstable_idx = []
- elif sparse_conv_intermediate_bounds and unstable_size < 0.8 * dim:
- num_channel = node.output_shape[-3]
- # Identity patch size: (ouc_c, 1, 1, 1, out_c, 1, 1).
- patches = (torch.eye(num_channel, device=self.device)).view(num_channel, 1, 1, 1, num_channel, 1, 1)
- # Expand to (out_c, 1, unstable_size, out_c, 1, 1).
- patches = patches.expand(-1, 1, node.output_shape[-2], node.output_shape[-1], -1, 1, 1)
- patches = patches[unstable_idx[0], :, unstable_idx[1], unstable_idx[2]]
- # Expand with the batch dimension. Final shape (unstable_size, batch_size, out_c, 1, 1).
- patches = patches.expand(-1, batch_size, -1, -1, -1)
- newC = Patches(patches, 1, 0, patches.shape, unstable_idx=unstable_idx, output_shape=node.output_shape)
- reduced_dim = True
- else:
- unstable_idx = None
- if not reduced_dim:
- num_channel = node.output_shape[-3]
- # Identity patch size: (ouc_c, 1, 1, 1, out_c, 1, 1).
- patches = (torch.eye(num_channel, device=self.device)).view(num_channel, 1, 1, 1, num_channel, 1, 1)
- # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1).
- patches = patches.expand(-1, batch_size, node.output_shape[-2], node.output_shape[-1], -1, 1, 1)
- newC = Patches(patches, 1, 0, patches.shape, output_shape=node.output_shape)
- else:
- if sparse_intermediate_bounds:
- unstable_idx, unstable_size = self.get_unstable_locations(node, aux_reference_bounds)
- if unstable_size == 0:
- # Do nothing, no bounds will be computed.
- reduced_dim = True
- unstable_idx = []
- # Number of unstable neurons after merging.
- elif unstable_size < 0.9 * dim:
- # Create a C matrix.
- newC = torch.zeros([1, unstable_size, dim], device=self.device)
- # Fill the corresponding elements to 1.0
- newC[0, torch.arange(unstable_size), unstable_idx] = 1.0
- newC = newC.expand(batch_size, -1, -1).view(batch_size, unstable_size, *node.output_shape[1:])
- reduced_dim = True
- # print(f'layer {node.name} total {dim} unstable {unstable_size} newC {newC.size()}')
- else:
- unstable_idx = None
- if not reduced_dim:
- if dim > 1000:
- warnings.warn(f"Creating an identity matrix with size {dim}x{dim} for node {node}. This may indicate poor performance for bound computation. If you see this message on a small network please submit a bug report.", stacklevel=2)
- newC = torch.eye(dim, device=self.device) \
- .unsqueeze(0).expand(batch_size, -1, -1) \
- .view(batch_size, dim, *node.output_shape[1:])
- if unstable_idx is None or unstable_size > 0:
- # Compute backward bounds only when there are unstable neurons, or when we don't know which neurons are unstable.
- self._backward_general(C=newC, node=node, root=root, return_A=False, intermediate_constr=intermediate_constr, unstable_idx=unstable_idx)
-
- if reduced_dim:
- if unstable_size > 0:
- # If we only calculated unstable neurons, we need to scatter the results back based on reference bounds.
- if isinstance(unstable_idx, tuple):
- new_lower = aux_reference_bounds[node.name][0].detach().clone()
- new_upper = aux_reference_bounds[node.name][1].detach().clone()
- # Conv layer with patches, the unstable_idx is a 3-element tuple for 3 indices (C, H,W) of unstable neurons.
- new_lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = node.lower
- new_upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = node.upper
- else:
- # Other layers.
- new_lower = aux_reference_bounds[node.name][0].detach().clone().view(batch_size, -1)
- new_upper = aux_reference_bounds[node.name][1].detach().clone().view(batch_size, -1)
- new_lower[:, unstable_idx] = node.lower.view(batch_size, -1)
- new_upper[:, unstable_idx] = node.upper.view(batch_size, -1)
- # print(f'{node.name} {node} bound diff {(new_lower.view(-1) - aux_reference_bounds[node.name][0].view(-1)).abs().sum()} {(new_upper.view(-1) - aux_reference_bounds[node.name][1].view(-1)).abs().sum()}')
- node.lower = new_lower.view(batch_size, *node.output_shape[1:])
- node.upper = new_upper.view(batch_size, *node.output_shape[1:])
- else:
- # No unstable neurons. Skip the update.
- node.lower = aux_reference_bounds[node.name][0].detach().clone()
- node.upper = aux_reference_bounds[node.name][1].detach().clone()
- # node.lower and node.upper (intermediate bounds) are computed in the above function.
- # If we have bound references, we set them here to always obtain a better set of bounds.
- if node.name in reference_bounds:
- # Initially, the reference bound and the computed bound can be exactly the same when intermediate layer beta is 0. This will prevent gradients flow. So we need a small guard here.
- if intermediate_constr is not None:
- # Intermediate layer beta is used.
- # Note that we cannot just take the reference bounds if they are better - this makes alphas have zero gradients.
- node.lower = torch.max((0.9 * reference_bounds[node.name][0] + 0.1 * node.lower), node.lower)
- node.upper = torch.min((0.9 * reference_bounds[node.name][1] + 0.1 * node.upper), node.upper)
- # Additionally, if the reference bounds say a neuron is stable, we always keep it. (FIXME: this is for ReLU only).
- lower_stable = reference_bounds[node.name][0] >= 0.
- node.lower[lower_stable] = reference_bounds[node.name][0][lower_stable]
- upper_stable = reference_bounds[node.name][1] <= 0.
- node.upper[upper_stable] = reference_bounds[node.name][1][upper_stable]
- else:
- # MIP solved intermediate layer bounds.
- # Set the intermediate layer bounds using reference bounds, always choosing the tighter one.
- node.lower = torch.max(reference_bounds[node.name][0] - 1e-5, node.lower)
- node.upper = torch.min(reference_bounds[node.name][1] + 1e-5, node.upper)
- # Otherwise, we only use reference bounds to check which neurons are unstable.
-
- if method == 'backward':
- # This is for the final output bound. No need to pass in intermediate layer beta constraints.
- return self._backward_general(C=C, node=final, root=root, bound_lower=bound_lower, bound_upper=bound_upper,
- return_A=return_A, needed_A_dict=needed_A_dict, average_A=average_A, A_dict=A_dict,
- return_b=return_b, b_dict=b_dict, unstable_idx=alpha_idx, need_A_only=need_A_only)
- elif method == 'forward':
- return self._forward_general(C=C, node=final, root=root, dim_in=dim_in, concretize=True)
- else:
- raise NotImplementedError
-
- """ improvement on merging BoundLinear, BoundGatherElements and BoundSub
- when loss fusion is used in training"""
- def _IBP_loss_fusion(self, node, C):
- # not using loss fusion
- if not self.bound_opts.get('loss_fusion', False):
- return None
-
- # Currently this function has issues in more complicated networks.
- if self.bound_opts.get('no_ibp_loss_fusion', False):
- return None
-
- if C is None and isinstance(node, BoundSub):
- node_gather = self._modules[node.input_name[1]]
- if isinstance(node_gather, BoundGatherElements) or isinstance(node_gather, BoundGatherAten):
- node_linear = self._modules[node.input_name[0]]
- node_start = self._modules[node_linear.input_name[0]]
- if isinstance(node_linear, BoundLinear):
- w = self._modules[node_linear.input_name[1]].param
- b = self._modules[node_linear.input_name[2]].param
- labels = self._modules[node_gather.input_name[1]]
- if not hasattr(node_start, 'interval'):
- self._IBP_general(node_start)
- for inp in node_gather.input_name:
- n = self._modules[inp]
- if not hasattr(n, 'interval'):
- self._IBP_general(n)
- if torch.isclose(labels.lower, labels.upper, 1e-8).all():
- labels = labels.lower
- batch_size = labels.shape[0]
- w = w.expand(batch_size, *w.shape)
- w = w - torch.gather(w, dim=1,
- index=labels.unsqueeze(-1).repeat(1, w.shape[1], w.shape[2]))
- b = b.expand(batch_size, *b.shape)
- b = b - torch.gather(b, dim=1,
- index=labels.repeat(1, b.shape[1]))
- lower, upper = node_start.interval
- lower, upper = lower.unsqueeze(1), upper.unsqueeze(1)
- node.lower, node.upper = node_linear.interval_propagate(
- (lower, upper), (w, w), (b.unsqueeze(1), b.unsqueeze(1)))
- node.interval = node.lower, node.upper = node.lower.squeeze(1), node.upper.squeeze(1)
- return node.interval
- return None
-
- def _IBP_general(self, node=None, C=None):
- if self.bound_opts.get('loss_fusion', False):
- res = self._IBP_loss_fusion(node, C)
- if res is not None:
- return res
-
- if not node.perturbed and hasattr(node, 'forward_value'):
- node.lower, node.upper = node.interval = (node.forward_value, node.forward_value)
- if self.ibp_relative:
- node.interval = Interval(
- None, None,
- nominal=node.forward_value,
- lower_offset=torch.zeros_like(node.forward_value),
- upper_offset=torch.zeros_like(node.forward_value))
-
- if not hasattr(node, 'interval'):
- for n in node.inputs:
- if not hasattr(n, 'interval'):
- self._IBP_general(n)
- inp = [n_pre.interval for n_pre in node.inputs]
- if C is not None and isinstance(node, BoundLinear) and not node.is_input_perturbed(1):
- # merge the last BoundLinear node with the specification, available when
- # weights of this layer are not perturbed
- return node.interval_propagate(*inp, C=C)
- else:
- node.interval = node.interval_propagate(*inp)
-
- if self.ibp_relative:
- node.lower, node.upper = node.interval.lower, node.interval.upper
- else:
- node.lower, node.upper = node.interval
- if isinstance(node.lower, torch.Size):
- node.lower = torch.tensor(node.lower)
- node.interval = (node.lower, node.upper)
- if isinstance(node.upper, torch.Size):
- node.upper = torch.tensor(node.upper)
- node.interval = (node.lower, node.upper)
-
- if C is not None:
- return BoundLinear.interval_propagate(None, node.interval, C=C)
- else:
- return node.interval
-
- def _addA(self, A1, A2):
- """ Add two A (each of them is either Tensor or Patches) """
- if type(A1) == torch.Tensor and type(A1) == torch.Tensor:
- return A1 + A2
- elif type(A1) == Patches and type(A2) == Patches:
- # Here we have to merge two patches, and if A1.stride != A2.stride, the patches will become a matrix,
- # in this case, we will avoid using this mode
- assert A1.stride == A2.stride, "A1.stride should be the same as A2.stride, otherwise, please use the matrix mode"
- if A1.unstable_idx is not None or A2.unstable_idx is not None:
- if A1.unstable_idx is not A2.unstable_idx: # Same tuple object.
- raise ValueError('Please set bound option "sparse_conv_intermediate_bounds" to False to run this model.')
- assert A1.output_shape == A2.output_shape
- # change paddings to merge the two patches
- if A1.padding != A2.padding:
- if A1.padding > A2.padding:
- A2 = A2._replace(patches=F.pad(A2.patches, (
- A1.padding - A2.padding, A1.padding - A2.padding, A1.padding - A2.padding,
- A1.padding - A2.padding)))
- else:
- A1 = A1._replace(patches=F.pad(A1.patches, (
- A2.padding - A1.padding, A2.padding - A1.padding, A2.padding - A1.padding,
- A2.padding - A1.padding)))
- sum_ret = A1.patches + A2.patches
- return Patches(sum_ret, A2.stride, max(A1.padding, A2.padding), sum_ret.shape, unstable_idx=A1.unstable_idx, output_shape=A1.output_shape)
- else:
- if type(A1) == Patches:
- pieces = A1.patches
- stride = A1.stride
- padding = A1.padding
- patch_output_shape = A1.output_shape
- patch_unstable_idx = A1.unstable_idx
- # Patches has shape (out_c, batch, out_h, out_w, in_c, h, w).
- input_shape = A2.shape[3:]
- matrix = A2
- if type(A2) == Patches:
- pieces = A2.patches
- stride = A2.stride
- padding = A2.padding
- patch_output_shape = A2.output_shape
- patch_unstable_idx = A2.unstable_idx
- input_shape = A1.shape[3:]
- matrix = A1
- A1_matrix = patches_to_matrix(
- pieces, input_shape, stride, padding, output_shape=patch_output_shape, unstable_idx=patch_unstable_idx)
- return A1_matrix.transpose(0,1) + matrix
-
- def _backward_general(self, C=None, node=None, root=None, bound_lower=True, bound_upper=True,
- return_A=False, needed_A_dict=None, average_A=False, A_dict=None, return_b=False, b_dict=None,
- intermediate_constr=None, unstable_idx=None, need_A_only=False):
- logger.debug('Backward from ({})[{}]'.format(node, node.name))
- _print_time = False
-
- degree_out = {}
- for l in self._modules.values():
- l.bounded = True
- l.lA = l.uA = None
- degree_out[l.name] = 0
- queue = deque([node])
- all_nodes_before = []
- while len(queue) > 0:
- l = queue.popleft()
- self.backward_from[l.name].append(node)
- for l_pre in l.input_name:
- all_nodes_before.append(l_pre)
- degree_out[l_pre] += 1 # calculate the out degree
- if self._modules[l_pre].bounded:
- self._modules[l_pre].bounded = False
- queue.append(self._modules[l_pre])
- node.bounded = True
- if isinstance(C, Patches):
- if C.unstable_idx is None:
- # Patches have size (out_c, batch, out_h, out_w, c, h, w).
- out_c, batch_size, out_h, out_w = C.shape[:4]
- output_dim = out_c * out_h * out_w
- else:
- # Patches have size (unstable_size, batch, c, h, w).
- output_dim, batch_size = C.shape[:2]
- else:
- batch_size, output_dim = C.shape[:2]
-
- # The C matrix specified by the user has shape (batch, spec) but internally we have (spec, batch) format.
- if not isinstance(C, (eyeC, Patches, OneHotC)):
- C = C.transpose(0, 1)
- elif isinstance(C, eyeC):
- C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:]))
- elif isinstance(C, OneHotC):
- C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:]), index=C.index.transpose(0,-1), coeffs=None if C.coeffs is None else C.coeffs.transpose(0,-1))
-
- node.lA = C if bound_lower else None
- node.uA = C if bound_upper else None
- lb = ub = torch.tensor(0., device=self.device)
-
-
- # Save intermediate layer A matrices when required.
- A_record = {}
-
- queue = deque([node])
- while len(queue) > 0:
- l = queue.popleft() # backward from l
- l.bounded = True
-
- if return_b:
- b_dict[l.name] = { 'lower_b': lb, 'upper_b': ub }
-
- if l.name in self.root_name or l == root: continue
-
- for l_pre in l.input_name: # if all the succeeds are done, then we can turn to this node in the next iteration.
- _l = self._modules[l_pre]
- degree_out[l_pre] -= 1
- if degree_out[l_pre] == 0:
- queue.append(_l)
-
- # Initially, l.lA or l.uA will be set to C for this node.
- if l.lA is not None or l.uA is not None:
- # Propagate lA and uA to a preceding node
- def add_bound(node, lA, uA):
- if lA is not None:
- if node.lA is None:
- # First A added to this node.
- node.zero_lA_mtx = l.zero_backward_coeffs_l
- node.lA = lA
- else:
- node.zero_lA_mtx = node.zero_lA_mtx and l.zero_backward_coeffs_l
- node.lA = self._addA(node.lA, lA)
- if uA is not None:
- if node.uA is None:
- # First A added to this node.
- node.zero_uA_mtx = l.zero_backward_coeffs_u
- node.uA = uA
- else:
- node.zero_uA_mtx = node.zero_uA_mtx and l.zero_backward_coeffs_u
- node.uA = self._addA(node.uA, uA)
-
- if _print_time:
- start_time = time.time()
-
- # FIXME make fixed nodes have fixed `forward_value` that is never cleaned out
- if not l.perturbed and hasattr(l, 'forward_value'):
- lb = lb + l.get_bias(l.lA, l.forward_value) # FIXME (09/16): shape for the bias of BoundConstant.
- ub = ub + l.get_bias(l.uA, l.forward_value)
- continue
-
- if l.zero_uA_mtx and l.zero_lA_mtx:
- # A matrices are all zero, no need to propagate.
- continue
-
- if isinstance(l, BoundRelu):
- A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs, start_node=node, unstable_idx=unstable_idx,
- beta_for_intermediate_layers=intermediate_constr is not None) # TODO: unify this interface.
- elif isinstance(l, BoundOptimizableActivation):
- A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs,
- start_shape=(prod(node.output_shape[1:]) if node.name != self.final_name
- else C.shape[0]), start_node=node)
- else:
- A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs)
-
- if _print_time:
- time_elapsed = time.time() - start_time
- if time_elapsed > 1e-3:
- print(l, time_elapsed)
- if lb.ndim > 0 and type(lower_b) == torch.Tensor and self.conv_mode == 'patches':
- # When we use patches mode, it's possible that we need to add two bias
- # one is from the Tensor mode and one is from the patches mode
- # And we need to detect this case and reshape the bias
- assert lower_b.ndim == 4 or lower_b.ndim == 2
- assert lb.ndim == 4 or lb.ndim == 2
- if lower_b.ndim < lb.ndim:
- lb = lb.transpose(0,1).reshape(lb.size(1), lb.size(0), -1)
- lb = lb.expand(lb.size(0), lb.size(1), lower_b.size(0)//lb.size(1))
- lb = lb.reshape(lb.size(0), -1).t()
- ub = ub.transpose(0,1).reshape(ub.size(1), ub.size(0), -1)
- ub = ub.expand(ub.size(0), ub.size(1), upper_b.size(0)//ub.size(1))
- ub = ub.reshape(ub.size(0), -1).t()
- elif lower_b.ndim > lb.ndim:
- lower_b = lower_b.transpose(0,1).reshape(lower_b.size(1), -1).t()
- upper_b = upper_b.transpose(0,1).reshape(upper_b.size(1), -1).t()
- lb = lb + lower_b
- ub = ub + upper_b
- if return_A and needed_A_dict and node.name in needed_A_dict:
- # FIXME ???
- # if isinstance(self._modules[l.output_name[0]], BoundRelu):
- # We save the A matrices after propagating through layer l if a ReLU follows l.
- # Here we saved the A *after* propagating backwards through this layer.
- # Note that we return the accumulated bias terms, to maintain linear relationship of this node (TODO: does not support ResNet).
- if l.name in needed_A_dict[node.name]:
- A_record.update({l.name: {
- "lA": A[0][0].transpose(0, 1).detach() if A[0][0] is not None else None,
- "uA": A[0][1].transpose(0, 1) .detach()if A[0][1] is not None else None,
- "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None,
- # When not used, lb or ub is tensor(0).
- "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None,
- }})
- A_dict.update({node.name: A_record})
- if need_A_only and set(needed_A_dict[node.name]) == set(A_record.keys()):
- # We have collected all A matrices we need. We can return now!
- A_dict.update({node.name: A_record})
- # Do not concretize to save time. We just need the A matrices.
- # return A matrix as a dict: {node.name: [A_lower, A_upper]}
- return None, None, A_dict
-
-
- for i, l_pre in enumerate(l.input_name):
- _l = self._modules[l_pre]
- add_bound(_l, lA=A[i][0], uA=A[i][1])
-
- if lb.ndim >= 2:
- lb = lb.transpose(0, 1)
- if ub.ndim >= 2:
- ub = ub.transpose(0, 1)
-
- if return_A and needed_A_dict and node.name in needed_A_dict:
- root_A_record = {}
- for i in range(len(root)):
- if root[i].lA is None and root[i].uA is None: continue
- if root[i].name in needed_A_dict[node.name]:
- root_A_record.update({root[i].name: {
- "lA": root[i].lA.transpose(0, 1).detach() if root[i].lA is not None else None,
- "uA": root[i].uA.transpose(0, 1).detach() if root[i].uA is not None else None,
- }})
- root_A_record.update(A_record) # merge to existing A_record
- A_dict.update({node.name: root_A_record})
-
- for i in range(len(root)):
- if root[i].lA is None and root[i].uA is None: continue
- if average_A and isinstance(root[i], BoundParams):
- lA = root[i].lA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].lA.shape) if bound_lower else None
- uA = root[i].uA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].uA.shape) if bound_upper else None
- else:
- lA, uA = root[i].lA, root[i].uA
-
- if not isinstance(root[i].lA, eyeC) and not isinstance(root[i].lA, Patches):
- lA = root[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None
- if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].uA, Patches):
- uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None
- if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
- if isinstance(root[i], BoundParams):
- # add batch_size dim for weights node
- lb = lb + root[i].perturbation.concretize(
- root[i].center.unsqueeze(0), lA,
- sign=-1, aux=root[i].aux) if bound_lower else None
- ub = ub + root[i].perturbation.concretize(
- root[i].center.unsqueeze(0), uA,
- sign=+1, aux=root[i].aux) if bound_upper else None
- else:
- lb = lb + root[i].perturbation.concretize(root[i].center, lA, sign=-1,
- aux=root[i].aux) if bound_lower else None
- ub = ub + root[i].perturbation.concretize(root[i].center, uA, sign=+1,
- aux=root[i].aux) if bound_upper else None
- else:
- if i < self.num_global_inputs:
- # Input node so there is a batch dimension
- fv = root[i].forward_value.view(batch_size, -1, 1)
- batch_size_ = batch_size
- else:
- # Parameter node so there is no batch dimension
- fv = root[i].forward_value.view(-1, 1)
- batch_size_ = 1
- if isinstance(lA, eyeC):
- lb = lb + fv.view(batch_size_, -1) if bound_lower else None
- else:
- lb = lb + lA.matmul(fv).squeeze(-1) if bound_lower else None
- if isinstance(uA, eyeC):
- ub = ub + fv.view(batch_size_, -1) if bound_upper else None
- else:
- ub = ub + uA.matmul(fv).squeeze(-1) if bound_upper else None
-
- if isinstance(C, Patches) and C.unstable_idx is not None:
- # Sparse patches; the output shape is (unstable_size, ).
- output_shape = [C.shape[0]]
- elif prod(node.output_shape[1:]) != output_dim and not isinstance(C, Patches):
- # For the output node, the shape of the bound follows C
- # instead of the original output shape
- # TODO Maybe don't set node.lower and node.upper in this case?
- # Currently some codes still depend on node.lower and node.upper
- output_shape = [-1]
- else:
- # Generally, the shape of the bounds match the output shape of the node
- output_shape = node.output_shape[1:]
- lb = node.lower = lb.view(batch_size, *output_shape) if bound_lower else None
- ub = node.upper = ub.view(batch_size, *output_shape) if bound_upper else None
-
- if return_A: return lb, ub, A_dict
-
- return lb, ub
-
- def _forward_general(self, C=None, node=None, root=None, dim_in=None, concretize=False):
- if hasattr(node, 'lower'):
- return node.lower, node.upper
-
- if not node.from_input:
- w, b = None, node.value
- node.linear = LinearBound(w, b, w, b, b, b)
- node.lower = node.upper = b
- node.interval = (node.lower, node.upper)
- return node.interval
-
- if not hasattr(node, 'linear'):
- for l_pre in node.input_name:
- l = self._modules[l_pre]
- if not hasattr(l, 'linear'):
- self._forward_general(node=l, root=root, dim_in=dim_in)
-
- inp = [self._modules[l_pre].linear for l_pre in node.input_name]
-
- if C is not None and isinstance(node, BoundLinear) and not node.is_input_perturbed(1):
- node.linear = node.bound_forward(dim_in, *inp, C=C)
- C_merged = True
- else:
- node.linear = node.bound_forward(dim_in, *inp)
- C_merged = False
-
- lw, uw = node.linear.lw, node.linear.uw
- lower, upper = node.linear.lb, node.linear.ub
-
- if C is not None and not C_merged:
- # FIXME use bound_forward of BoundLinear
- C_pos, C_neg = C.clamp(min=0), C.clamp(max=0)
- _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2))
- _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2))
- lw, uw = _lw, _uw
- _lower = torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) + \
- torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2))
- _upper = torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2)) + \
- torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2))
- lower, upper = _lower.squeeze(1), _upper.squeeze(1)
- else:
- lw, uw = node.linear.lw, node.linear.uw
- lower, upper = node.linear.lb, node.linear.ub
-
- if concretize:
- if node.linear.lw is not None:
- prev_dim_in = 0
- batch_size = lw.shape[0]
- assert (lw.ndim > 1)
- lA = lw.reshape(batch_size, dim_in, -1).transpose(1, 2)
- uA = uw.reshape(batch_size, dim_in, -1).transpose(1, 2)
- for i in range(len(root)):
- if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
- _lA = lA[:, :, prev_dim_in : (prev_dim_in + root[i].dim)]
- _uA = uA[:, :, prev_dim_in : (prev_dim_in + root[i].dim)]
- lower = lower + root[i].perturbation.concretize(
- root[i].center, _lA, sign=-1, aux=root[i].aux).view(lower.shape)
- upper = upper + root[i].perturbation.concretize(
- root[i].center, _uA, sign=+1, aux=root[i].aux).view(upper.shape)
- prev_dim_in += root[i].dim
- if C is None:
- node.linear = node.linear._replace(lower=lower, upper=upper)
- if C is None:
- node.lower, node.upper = lower, upper
- if not Benchmarking and torch.isnan(lower).any():
- import pdb
- pdb.set_trace()
- return lower, upper
-
- def _init_forward(self, root, dim_in):
- if dim_in == 0:
- raise ValueError("At least one node should have a specified perturbation")
- prev_dim_in = 0
- # Assumption: root[0] is the input node which implies batch_size
- batch_size = root[0].value.shape[0]
- for i in range(len(root)):
- if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
- shape = root[i].linear.lw.shape
- device = root[i].linear.lw.device
- dtype = root[i].linear.lw.dtype
- root[i].linear = root[i].linear._replace(
- lw=torch.cat([
- torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device, dtype=dtype),
- root[i].linear.lw,
- torch.zeros(shape[0], dim_in - shape[1], *shape[2:], device=device, dtype=dtype)
- ], dim=1),
- uw=torch.cat([
- torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device, dtype=dtype),
- root[i].linear.uw,
- torch.zeros(shape[0], dim_in - shape[1] - prev_dim_in, *shape[2:], device=device, dtype=dtype)
- ], dim=1)
- )
- if i >= self.num_global_inputs:
- root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat(
- *([batch_size] + [1] * self.forward_value.ndim))
- prev_dim_in += shape[1]
- else:
- b = fv = root[i].forward_value
- shape = fv.shape
- if root[i].from_input:
- w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device)
- else:
- w = None
- root[i].linear = LinearBound(w, b, w, b, b, b)
- root[i].lower = root[i].upper = b
- root[i].interval = (root[i].lower, root[i].upper)
-
- """Add perturbation to an intermediate node and it is treated as an independent
- node in bound computation."""
-
def add_intermediate_perturbation(self, node, perturbation):
+ """Add perturbation to an intermediate node and it is treated as an
+ independent node in bound computation."""
node.perturbation = perturbation
node.perturbed = True
        # NOTE: This change is currently irreversible.
if not node.name in self.root_name:
self.root_name.append(node.name)
-
-
-class BoundDataParallel(DataParallel):
- # https://github.com/huanzhang12/CROWN-IBP/blob/master/bound_layers.py
- # This is a customized DataParallel class for our project
- def __init__(self, *inputs, **kwargs):
- super(BoundDataParallel, self).__init__(*inputs, **kwargs)
- self._replicas = None
-
- # Overide the forward method
- def forward(self, *inputs, **kwargs):
- disable_multi_gpu = False # forward by single GPU
- no_replicas = False # forward by multi GPUs but without replicate
- if "disable_multi_gpu" in kwargs:
- disable_multi_gpu = kwargs["disable_multi_gpu"]
- kwargs.pop("disable_multi_gpu")
-
- if "no_replicas" in kwargs:
- no_replicas = kwargs["no_replicas"]
- kwargs.pop("no_replicas")
-
- if not self.device_ids or disable_multi_gpu:
- if kwargs.pop("get_property", False):
- return self.get_property(self, *inputs, **kwargs)
- return self.module(*inputs, **kwargs)
-
- if kwargs.pop("get_property", False):
- if self._replicas is None:
- assert 0, 'please call IBP/CROWN before get_property'
- if len(self.device_ids) == 1:
- return self.get_property(self.module, **kwargs)
- inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
- kwargs = list(kwargs)
- for i in range(len(kwargs)):
- kwargs[i]['model'] = self._replicas[i]
- outputs = self.parallel_apply([self.get_property] * len(kwargs), inputs, kwargs)
- return self.gather(outputs, self.output_device)
-
- # Only replicate during forward/IBP propagation. Not during interval bounds
- # and CROWN-IBP bounds, since weights have not been updated. This saves 2/3
- # of communication cost.
- if not no_replicas:
- if self._replicas is None: # first time
- self._replicas = self.replicate(self.module, self.device_ids)
- elif kwargs.get("method_opt", "forward") == "forward":
- self._replicas = self.replicate(self.module, self.device_ids)
- elif kwargs.get("x") is not None and kwargs.get("IBP") is True: #
- self._replicas = self.replicate(self.module, self.device_ids)
- # Update the input nodes to the ones within each replica respectively
- for bounded_module in self._replicas:
- for node in bounded_module._modules.values():
- node.inputs = [bounded_module._modules[name] for name in node.input_name]
-
- for t in chain(self.module.parameters(), self.module.buffers()):
- if t.device != self.src_device_obj:
- raise RuntimeError("module must have its parameters and buffers "
- "on device {} (device_ids[0]) but found one of "
- "them on device: {}".format(self.src_device_obj, t.device))
-
- # TODO: can be done in parallel, only support same ptb for all inputs per forward/IBP propagation
- if len(inputs) > 0 and hasattr(inputs[0], 'ptb') and inputs[0].ptb is not None:
- # compute bounds without x
- # inputs_scatter is a normal tensor, we need to assign ptb to it if inputs is a BoundedTensor
- inputs_scatter, kwargs = self.scatter((inputs, inputs[0].ptb.x_L, inputs[0].ptb.x_U), kwargs,
- self.device_ids)
- # inputs_scatter = inputs_scatter[0]
- bounded_inputs = []
- for input_s in inputs_scatter: # GPU numbers
- ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2])
- # bounded_inputs.append(tuple([(BoundedTensor(input_s[0][0], ptb))]))
- input_s = list(input_s[0])
- input_s[0] = BoundedTensor(input_s[0], ptb)
- input_s = tuple(input_s)
- bounded_inputs.append(input_s)
-
- # bounded_inputs = tuple(bounded_inputs)
- elif kwargs.get("x") is not None and hasattr(kwargs.get("x")[0], 'ptb') and kwargs.get("x")[0].ptb is not None:
- # compute bounds with x
- # kwargs['x'] is a normal tensor, we need to assign ptb to it
- x = kwargs.get("x")[0]
- bounded_inputs = []
- inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids)
- for input_s, kw_s in zip(inputs_scatter, kwargs): # GPU numbers
- ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2])
- kw_s['x'] = list(kw_s['x'])
- kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb)
- kw_s['x'] = (kw_s['x'])
- bounded_inputs.append(tuple(input_s[0], ))
- else:
- # normal forward
- inputs_scatter, kwargs = self.scatter(inputs, kwargs, self.device_ids)
- bounded_inputs = inputs_scatter
-
- if len(self.device_ids) == 1:
- return self.module(*bounded_inputs[0], **kwargs[0])
- outputs = self.parallel_apply(self._replicas[:len(bounded_inputs)], bounded_inputs, kwargs)
- return self.gather(outputs, self.output_device)
-
- @staticmethod
- def get_property(model, node_class=None, att_name=None, node_name=None):
- if node_name:
- # Find node by name
- # FIXME If we use `model.named_modules()`, the nodes have the
- # `BoundedModule` type rather than bound nodes.
- for node in model._modules.values():
- if node.name == node_name:
- return getattr(node, att_name)
- else:
- # Find node by class
- for _, node in model.named_modules():
- # Find the Exp neuron in computational graph
- if isinstance(node, node_class):
- return getattr(node, att_name)
-
- def state_dict(self, destination=None, prefix='', keep_vars=False):
- # add 'module.' here before each keys in self.module.state_dict() if needed
- return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
-
- def _named_members(self, get_members_fn, prefix='', recurse=True):
- return self.module._named_members(get_members_fn, prefix, recurse)
-
+ from .interval_bound import (
+ IBP_general, _IBP_loss_fusion, check_IBP_intermediate,
+ check_IBP_first_linear)
+ from .forward_bound import (
+ forward_general, forward_general_dynamic, init_forward)
+ from .backward_bound import (
+ backward_general, get_sparse_C, check_optimized_variable_sparsity,
+ restore_sparse_bounds, get_alpha_crown_start_nodes,
+ get_unstable_locations, batched_backward)
+ from .optimized_bounds import get_optimized_bounds, init_slope
+ from .beta_crown import (
+ beta_bias, save_best_intermediate_betas,
+ print_optimized_beta)
+
+
+ from .solver_module import build_solver_module, _build_solver_input, _build_solver_general
diff --git a/auto_LiRPA/bound_multi_gpu.py b/auto_LiRPA/bound_multi_gpu.py
new file mode 100644
index 0000000..057ef42
--- /dev/null
+++ b/auto_LiRPA/bound_multi_gpu.py
@@ -0,0 +1,127 @@
+from torch.nn import DataParallel
+from .perturbations import *
+from .bounded_tensor import BoundedTensor
+from itertools import chain
+
+class BoundDataParallel(DataParallel):
+ # https://github.com/huanzhang12/CROWN-IBP/blob/master/bound_layers.py
+ # This is a customized DataParallel class for our project
+ def __init__(self, *inputs, **kwargs):
+ super(BoundDataParallel, self).__init__(*inputs, **kwargs)
+ self._replicas = None
+
+ # Override the forward method
+ def forward(self, *inputs, **kwargs):
+ disable_multi_gpu = False # forward by single GPU
+ no_replicas = False # forward by multi GPUs but without replicate
+ if "disable_multi_gpu" in kwargs:
+ disable_multi_gpu = kwargs["disable_multi_gpu"]
+ kwargs.pop("disable_multi_gpu")
+
+ if "no_replicas" in kwargs:
+ no_replicas = kwargs["no_replicas"]
+ kwargs.pop("no_replicas")
+
+ if not self.device_ids or disable_multi_gpu:
+ if kwargs.pop("get_property", False):
+ return self.get_property(self, *inputs, **kwargs)
+ return self.module(*inputs, **kwargs)
+
+ if kwargs.pop("get_property", False):
+ if self._replicas is None:
+ assert 0, 'please call IBP/CROWN before get_property'
+ if len(self.device_ids) == 1:
+ return self.get_property(self.module, **kwargs)
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ kwargs = list(kwargs)
+ for i in range(len(kwargs)):
+ kwargs[i]['model'] = self._replicas[i]
+ outputs = self.parallel_apply([self.get_property] * len(kwargs), inputs, kwargs)
+ return self.gather(outputs, self.output_device)
+
+ # Only replicate during forward/IBP propagation, not when computing interval
+ # bounds and CROWN-IBP bounds, since the weights have not been updated. This
+ # saves about 2/3 of the communication cost.
+ if not no_replicas:
+ if self._replicas is None: # first time
+ self._replicas = self.replicate(self.module, self.device_ids)
+ elif kwargs.get("method_opt", "forward") == "forward":
+ self._replicas = self.replicate(self.module, self.device_ids)
+ elif kwargs.get("x") is not None and kwargs.get("IBP") is True: #
+ self._replicas = self.replicate(self.module, self.device_ids)
+ # Update the input nodes to the ones within each replica respectively
+ for bounded_module in self._replicas:
+ for node in bounded_module._modules.values():
+ node.inputs = [bounded_module[name] for name in node.input_name]
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(self.src_device_obj, t.device))
+
+ # TODO: this can be done in parallel; currently only the same ptb is supported for all inputs per forward/IBP propagation
+ if len(inputs) > 0 and hasattr(inputs[0], 'ptb') and inputs[0].ptb is not None:
+ # compute bounds when x is passed positionally rather than as a kwarg
+ # inputs_scatter contains plain tensors; we need to re-attach ptb to them since inputs[0] is a BoundedTensor
+ inputs_scatter, kwargs = self.scatter((inputs, inputs[0].ptb.x_L, inputs[0].ptb.x_U), kwargs,
+ self.device_ids)
+ # inputs_scatter = inputs_scatter[0]
+ bounded_inputs = []
+ for input_s in inputs_scatter: # GPU numbers
+ # FIXME other perturbations are not supported yet
+ assert isinstance(inputs[0].ptb, PerturbationLpNorm)
+ ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2])
+ input_s = list(input_s[0])
+ input_s[0] = BoundedTensor(input_s[0], ptb)
+ input_s = tuple(input_s)
+ bounded_inputs.append(input_s)
+
+ # bounded_inputs = tuple(bounded_inputs)
+ elif kwargs.get("x") is not None and hasattr(kwargs.get("x")[0], 'ptb') and kwargs.get("x")[0].ptb is not None:
+ # compute bounds with x
+ # kwargs['x'] is a normal tensor, we need to assign ptb to it
+ x = kwargs.get("x")[0]
+ bounded_inputs = []
+ inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids)
+ for input_s, kw_s in zip(inputs_scatter, kwargs): # GPU numbers
+ # FIXME other perturbations are not supported yet
+ assert isinstance(x.ptb, PerturbationLpNorm)
+ ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2])
+ kw_s['x'] = list(kw_s['x'])
+ kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb)
+ kw_s['x'] = tuple(kw_s['x'])
+ bounded_inputs.append(tuple(input_s[0]))
+ else:
+ # normal forward
+ inputs_scatter, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ bounded_inputs = inputs_scatter
+
+ if len(self.device_ids) == 1:
+ return self.module(*bounded_inputs[0], **kwargs[0])
+ outputs = self.parallel_apply(self._replicas[:len(bounded_inputs)], bounded_inputs, kwargs)
+ return self.gather(outputs, self.output_device)
+
+ @staticmethod
+ def get_property(model, node_class=None, att_name=None, node_name=None):
+ if node_name:
+ # Find node by name
+ # FIXME If we use `model.named_modules()`, the nodes have the
+ # `BoundedModule` type rather than bound nodes.
+ for node in model._modules.values():
+ if node.name == node_name:
+ return getattr(node, att_name)
+ else:
+ # Find node by class
+ for _, node in model.named_modules():
+ # Find the first node of the given class in the computational graph
+ if isinstance(node, node_class):
+ return getattr(node, att_name)
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ # add 'module.' before each key in self.module.state_dict() if needed
+ return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+ def _named_members(self, get_members_fn, prefix='', recurse=True):
+ return self.module._named_members(get_members_fn, prefix, recurse)
+
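+# A minimal usage sketch (hypothetical, not part of the library API): it shows
+# how the forward() dispatch above is typically driven. We assume `module` is
+# an already-built BoundedModule and `x` is a BoundedTensor carrying a
+# PerturbationLpNorm, so the kwargs below match the branches in forward().
+def _bound_data_parallel_usage_sketch(module, x):
+    model = BoundDataParallel(module)
+    # Regular forward pass; replicas are created (or refreshed) here.
+    y = model(x)
+    # IBP pass reusing the existing replicas to save communication cost.
+    lb, ub = model(method_opt='compute_bounds', x=(x,), IBP=True, no_replicas=True)
+    return y, lb, ub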
diff --git a/auto_LiRPA/bound_op_map.py b/auto_LiRPA/bound_op_map.py
index 087dab3..c33d52e 100644
--- a/auto_LiRPA/bound_op_map.py
+++ b/auto_LiRPA/bound_op_map.py
@@ -6,19 +6,7 @@
}
def register_custom_op(op_name: str, bound_obj: Bound) -> None:
- """ Register a custom operator.
-
- Args:
- op_name (str): Name of the custom operator
-
- bound_obj (Bound): The corresponding Bound class for the operator.
- """
bound_op_map[op_name] = bound_obj
def unregister_custom_op(op_name: str) -> None:
- """ Unregister a custom operator.
-
- Args:
- op_name (str): Name of the custom operator
- """
- bound_op_map.pop(op_name)
\ No newline at end of file
+ bound_op_map.pop(op_name)
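+
+# Hedged example (hypothetical op name; any Bound subclass works): temporarily
+# register a custom operator class under an ONNX op name, then remove it.
+def _custom_op_example(bound_cls):
+    register_custom_op('custom::MyOp', bound_cls)
+    try:
+        pass  # build a BoundedModule and compute bounds here
+    finally:
+        unregister_custom_op('custom::MyOp')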
diff --git a/auto_LiRPA/bounded_tensor.py b/auto_LiRPA/bounded_tensor.py
index ce922d9..cb3302f 100644
--- a/auto_LiRPA/bounded_tensor.py
+++ b/auto_LiRPA/bounded_tensor.py
@@ -1,3 +1,4 @@
+import copy
import torch
import torch.nn as nn
from torch import Tensor as Tensor
@@ -26,7 +27,7 @@ def __repr__(self):
return '<BoundedTensor: {}>'.format(super().__repr__())
def clone(self, *args, **kwargs):
- tensor = BoundedTensor(super().clone(*args, **kwargs), self.ptb)
+ tensor = BoundedTensor(super().clone(*args, **kwargs), copy.deepcopy(self.ptb))
return tensor
def _func(self, func, *args, **kwargs):
@@ -38,6 +39,13 @@ def _func(self, func, *args, **kwargs):
# Copy to other devices with perturbation
def to(self, *args, **kwargs):
+ # FIXME add a general "to" function in the perturbation class instead of here.
+ if hasattr(self.ptb, 'x_L') and isinstance(self.ptb.x_L, Tensor):
+ self.ptb.x_L = self.ptb.x_L.to(*args, **kwargs)
+ if hasattr(self.ptb, 'x_U') and isinstance(self.ptb.x_U, Tensor):
+ self.ptb.x_U = self.ptb.x_U.to(*args, **kwargs)
+ if hasattr(self.ptb, 'eps') and isinstance(self.ptb.eps, Tensor):
+ self.ptb.eps = self.ptb.eps.to(*args, **kwargs)
return self._func(super().to, *args, **kwargs)
@classmethod
diff --git a/auto_LiRPA/cuda/cuda_kernels.cu b/auto_LiRPA/cuda/cuda_kernels.cu
new file mode 100644
index 0000000..a67a30b
--- /dev/null
+++ b/auto_LiRPA/cuda/cuda_kernels.cu
@@ -0,0 +1,40 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+__global__ void cuda_double2float_rd_kernel(const double* __restrict__ inputs,
+ float* __restrict__ outputs, const size_t tensor_size) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < tensor_size) {
+ outputs[idx] = __double2float_rd(inputs[idx]);
+ }
+}
+
+__global__ void cuda_double2float_ru_kernel(const double* __restrict__ inputs,
+ float* __restrict__ outputs, const size_t tensor_size) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < tensor_size) {
+ outputs[idx] = __double2float_ru(inputs[idx]);
+ }
+}
+
+torch::Tensor cuda_double2float_forward(torch::Tensor input,
+ const std::string direction) {
+ auto total_elem = input.numel();
+ auto output = torch::empty_like(input, torch::ScalarType::Float);
+
+ const int threads = 1024;
+ const int blocks = (total_elem + threads - 1) / threads;
+
+ if (direction == "down") {
+ cuda_double2float_rd_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);
+ }
+ else {
+ cuda_double2float_ru_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);
+ }
+ return output;
+}
+
diff --git a/auto_LiRPA/cuda/cuda_utils.cpp b/auto_LiRPA/cuda/cuda_utils.cpp
new file mode 100644
index 0000000..867c910
--- /dev/null
+++ b/auto_LiRPA/cuda/cuda_utils.cpp
@@ -0,0 +1,25 @@
+#include <torch/extension.h>
+
+#include <string>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+
+torch::Tensor cuda_double2float_forward(
+ torch::Tensor input, const std::string direction);
+
+torch::Tensor double2float_forward(
+ torch::Tensor input, const std::string direction) {
+ TORCH_CHECK((direction == "down") || (direction == "up"), "Unsupported direction, must be down or up.");
+ TORCH_CHECK(input.type().scalarType() == torch::ScalarType::Double, "This function only supports DoubleTensor as inputs.");
+ CHECK_CUDA(input);
+ return cuda_double2float_forward(input, direction);
+}
+
+/*
+ * Usage: double2float(tensor, direction)
+ * "tensor" must be a DoubleTensor on GPU.
+ * "direction" is a string, can be "up" (round up) or "down" (round down).
+ */
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("double2float", &double2float_foward, "Convert double to float with rounding direction control (direction = 'up' or 'down').");
+}
diff --git a/auto_LiRPA/cuda_utils.py b/auto_LiRPA/cuda_utils.py
new file mode 100644
index 0000000..912ac6c
--- /dev/null
+++ b/auto_LiRPA/cuda_utils.py
@@ -0,0 +1,126 @@
+import os
+import sys
+import time
+import torch
+from torch.utils.cpp_extension import load, BuildExtension, CUDAExtension
+from setuptools import setup
+
+class DummyCudaClass:
+ """A dummy class with error message when a CUDA function is called."""
+ def __getattr__(self, attr):
+ if attr == "double2float":
+ # When CUDA module is not built successfully, use a workaround.
+ def _f(x, d):
+ print('WARNING: Missing CUDA kernels. Please enable CUDA build by setting environment variable AUTOLIRPA_ENABLE_CUDA_BUILD=1 for the correct behavior!')
+ return x.float()
+ return _f
+ def _f(*args, **kwargs):
+ raise RuntimeError(f"method {attr} not available because CUDA module was not built.")
+ return _f
+
+if __name__ == "__main__" and len(sys.argv) > 1:
+ # Build and install native CUDA modules that can be directly imported later
+ print('Building and installing native CUDA modules...')
+ setup(
+ name='auto_LiRPA_cuda_utils',
+ ext_modules=[CUDAExtension('auto_LiRPA_cuda_utils', [
+ 'auto_LiRPA/cuda/cuda_utils.cpp',
+ 'auto_LiRPA/cuda/cuda_kernels.cu'
+ ])],
+ cmdclass={'build_ext': BuildExtension.with_options()},
+ )
+ exit(0)
+
+if torch.cuda.is_available() and os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False):
+ try:
+ import auto_LiRPA_cuda_utils as _cuda_utils
+ except:
+ print('CUDA modules have not been installed')
+ try:
+ print('Building native CUDA modules...')
+ code_dir = os.path.dirname(os.path.abspath(__file__))
+ verbose = os.environ.get('AUTOLIRPA_DEBUG_CUDA_BUILD', None) is not None
+ _cuda_utils = load(
+ 'cuda_utils', [os.path.join(code_dir, 'cuda', 'cuda_utils.cpp'), os.path.join(code_dir, 'cuda', 'cuda_kernels.cu')], verbose=verbose)
+ print('CUDA modules have been built.')
+ except:
+ print('CUDA module build failure. Some features will be unavailable.')
+ print('Please make sure the latest CUDA toolkit is installed in your system.')
+ if verbose:
+ print(sys.exc_info()[2])
+ else:
+ print('Set environment variable AUTOLIRPA_DEBUG_CUDA_BUILD=1 to view build log.')
+ _cuda_utils = DummyCudaClass()
+else:
+ if os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False):
+ print('CUDA unavailable. Some features are disabled.')
+ _cuda_utils = DummyCudaClass()
+
+double2float = _cuda_utils.double2float
+
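+# Hedged helper sketch (not part of the library API): a sound float32
+# enclosure of a float64 tensor built from the directed rounding above.
+def _enclosing_float_interval(t):
+    """Return (lower, upper) float32 tensors that enclose double tensor `t`."""
+    return double2float(t, 'down'), double2float(t, 'up')
+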
+def test_double2float():
+ # Test the double2float function.
+ import time
+ shape = (3,4,5)
+
+ a = torch.randn(size=shape, dtype=torch.float64, device='cuda')
+ a = a.transpose(0,1)
+
+ au = _cuda_utils.double2float(a, "up")
+ ad = _cuda_utils.double2float(a, "down")
+
+ print(a.size(), au.size(), ad.size())
+
+ a_flatten = a.reshape(-1)
+ au_flatten = au.reshape(-1)
+ ad_flatten = ad.reshape(-1)
+
+ for i in range(a_flatten.numel()):
+ ai = a_flatten[i].item()
+ aui = au_flatten[i].item()
+ adi = ad_flatten[i].item()
+ print(adi, ai, aui)
+ assert adi <= ai
+ assert aui >= ai
+ del a, au, ad, a_flatten, au_flatten, ad_flatten
+
+ # Performance benchmark.
+ for j in [1, 4, 16, 64, 256, 1024]:
+ shape = (j, 512, 1024)
+ print(f'shape: {shape}')
+ t = torch.randn(size=shape, dtype=torch.float64, device='cuda')
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(10):
+ tt = t.float()
+ torch.cuda.synchronize()
+ del tt
+ pytorch_time = time.time() - start_time
+ print(f'pytorch rounding time: {pytorch_time:.4f}')
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(10):
+ tu = _cuda_utils.double2float(t, "up")
+ torch.cuda.synchronize()
+ del tu
+ roundup_time = time.time() - start_time
+ print(f'cuda round up time: {roundup_time:.4f}')
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(10):
+ td = _cuda_utils.double2float(t, "down")
+ torch.cuda.synchronize()
+ del td
+ rounddown_time = time.time() - start_time
+ print(f'cuda round down time: {rounddown_time:.4f}')
+
+ del t
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ # Some tests. They cannot run automatically on CI because Travis does not have CUDA.
+ test_double2float()
diff --git a/auto_LiRPA/forward_bound.py b/auto_LiRPA/forward_bound.py
new file mode 100644
index 0000000..a090208
--- /dev/null
+++ b/auto_LiRPA/forward_bound.py
@@ -0,0 +1,306 @@
+from auto_LiRPA.beta_crown import print_optimized_beta
+import torch
+from torch import Tensor
+import warnings
+from .bound_ops import *
+from .utils import *
+from .backward_bound import batched_backward
+from .linear_bound import LinearBound
+from .perturbations import PerturbationLpNorm
+
+import sys
+sys.setrecursionlimit(1000000)
+
+def forward_general(self, C=None, node=None, concretize=False, offset=0):
+ if self.bound_opts['dynamic_forward']:
+ return self.forward_general_dynamic(C, node, concretize, offset)
+
+ if C is None:
+ if hasattr(node, 'linear'):
+ return node.linear.lower, node.linear.upper
+ if not node.from_input:
+ node.linear = LinearBound(None, node.value, None, node.value, node.value, node.value)
+ return node.value, node.value
+ if not node.perturbed:
+ node.lower = node.upper = self.get_forward_value(node)
+ if hasattr(node, 'lower'):
+ node.linear = LinearBound(None, node.lower, None, node.upper, node.lower, node.upper)
+ return node.lower, node.upper
+
+ for l_pre in node.inputs:
+ if not hasattr(l_pre, 'linear'):
+ self.forward_general(node=l_pre, offset=offset)
+ inp = [l_pre.linear for l_pre in node.inputs]
+ node._start = '_forward'
+ if (C is not None and isinstance(node, BoundLinear) and
+ not node.is_input_perturbed(1) and not node.is_input_perturbed(2)):
+ linear = node.bound_forward(self.dim_in, *inp, C=C)
+ C_merged = True
+ else:
+ linear = node.linear = node.bound_forward(self.dim_in, *inp)
+ C_merged = False
+
+ lw, uw = linear.lw, linear.uw
+ lower, upper = linear.lb, linear.ub
+
+ if C is not None and not C_merged:
+ # FIXME use bound_forward of BoundLinear
+ C_pos, C_neg = C.clamp(min=0), C.clamp(max=0)
+ _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2))
+ _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2))
+ lw, uw = _lw, _uw
+ _lower = torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) + \
+ torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2))
+ _upper = torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2)) + \
+ torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2))
+ lower, upper = _lower.squeeze(1), _upper.squeeze(1)
+
+ logger.debug(f'Forward bounds to {node}')
+
+ if concretize:
+ if lw is not None or uw is not None:
+ prev_dim_in = 0
+ batch_size = lw.shape[0]
+ assert (lw.ndim > 1)
+ lA = lw.reshape(batch_size, self.dim_in, -1).transpose(1, 2)
+ uA = uw.reshape(batch_size, self.dim_in, -1).transpose(1, 2)
+ for i in range(len(self.root)):
+ if hasattr(self.root[i], 'perturbation') and self.root[i].perturbation is not None:
+ _lA = lA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)]
+ _uA = uA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)]
+ lower = lower + self.root[i].perturbation.concretize(
+ self.root[i].center, _lA, sign=-1, aux=self.root[i].aux).view(lower.shape)
+ upper = upper + self.root[i].perturbation.concretize(
+ self.root[i].center, _uA, sign=+1, aux=self.root[i].aux).view(upper.shape)
+ prev_dim_in += self.root[i].dim
+ linear.lower, linear.upper = lower, upper
+
+ if C is None:
+ node.linear = linear
+ node.lower, node.upper = lower, upper
+
+ if self.bound_opts['forward_refinement']:
+ need_refinement = False
+ for out in node.output_name:
+ out_node = self[out]
+ if getattr(out_node, 'nonlinear', False):
+ need_refinement = True
+ for i in getattr(out_node, 'requires_input_bounds', []):
+ if out_node.inputs[i] == node:
+ need_refinement = True
+ break
+ if need_refinement:
+ forward_refinement(self, node)
+ return lower, upper
+
+
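+# A hedged reference sketch (not the library code path): concretizing linear
+# bounds lw x + lb <= f(x) <= uw x + ub over a box [x_L, x_U], mirroring the
+# interval arithmetic that forward_general() relies on for Linf perturbations.
+def _concretize_over_box_sketch(lw, lb, uw, ub, x_L, x_U):
+    # lw, uw: (out_dim, in_dim); lb, ub: (out_dim,); x_L, x_U: (in_dim,)
+    lower = lb + lw.clamp(min=0) @ x_L + lw.clamp(max=0) @ x_U
+    upper = ub + uw.clamp(min=0) @ x_U + uw.clamp(max=0) @ x_L
+    return lower, upper
+
+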
+def forward_general_dynamic(
+ self, C=None, node=None, concretize=False, offset=0):
+ max_dim = self.bound_opts['forward_max_dim']
+
+ if C is None:
+ if hasattr(node, 'linear'):
+ assert not concretize
+
+ linear = node.linear
+ if offset == 0:
+ if linear.lw is None:
+ return linear
+ elif linear.lw.shape[1] <= max_dim:
+ return linear
+ if linear.lw is not None:
+ lw = linear.lw[:, offset:offset+max_dim]
+ x_L = linear.x_L[:, offset:offset+max_dim]
+ x_U = linear.x_U[:, offset:offset+max_dim]
+ tot_dim = linear.tot_dim
+ if offset == 0:
+ lb = linear.lb
+ else:
+ lb = torch.zeros_like(linear.lb)
+ else:
+ lw = x_L = x_U = None
+ tot_dim = 0
+ lb = linear.lb
+ return LinearBound(
+ lw, lb, lw, lb, x_L=x_L, x_U=x_U,
+ offset=offset, tot_dim=tot_dim,
+ )
+
+ # These cases have no coefficient tensor
+ if not node.from_input:
+ if concretize:
+ return node.value, node.value
+ else:
+ node.linear = LinearBound(
+ None, node.value, None, node.value, node.value, node.value)
+ return node.linear
+ if not node.perturbed:
+ if not hasattr(node, 'lower'):
+ node.lower = node.upper = self.get_forward_value(node)
+ raise NotImplementedError
+ if concretize:
+ return node.lower, node.upper
+ else:
+ if offset > 0:
+ lb = torch.zeros_like(node.lower)
+ else:
+ lb = node.lower
+ node.linear = LinearBound(None, lb, None, lb, node.lower, node.upper)
+ return node.linear
+
+ if offset == 0:
+ logger.debug(f'forward_general_dynamic: node={node}')
+
+ inp = []
+ for l_pre in node.inputs:
+ linear_inp = self.forward_general_dynamic(node=l_pre, offset=offset)
+ linear_inp.lower = getattr(l_pre, 'lower', None)
+ linear_inp.upper = getattr(l_pre, 'upper', None)
+ inp.append(linear_inp)
+ node._start = '_forward'
+ if (C is not None and isinstance(node, BoundLinear) and
+ not node.is_input_perturbed(1) and not node.is_input_perturbed(2)):
+ linear = node.bound_dynamic_forward(
+ *inp, C=C, max_dim=max_dim, offset=offset)
+ C_merged = True
+ else:
+ linear = node.bound_dynamic_forward(
+ *inp, max_dim=max_dim, offset=offset)
+ C_merged = False
+ if offset > 0:
+ linear.lb = linear.ub = torch.zeros_like(linear.lb)
+
+ lw, lb, tot_dim = linear.lw, linear.lb, linear.tot_dim
+ #logger.debug(f'forward_general_dynamic: node={node}, w_size={lw.shape[1]}, tot_dim={tot_dim}')
+
+ if C is not None and not C_merged:
+ # FIXME use bound_forward of BoundLinear
+ lw = torch.matmul(lw, C.transpose(-1, -2))
+ lb = torch.matmul(lb.unsqueeze(1), C.transpose(-1, -2)).squeeze(1)
+
+ if concretize:
+ lower = upper = lb
+ if lw is not None:
+ batch_size = lw.shape[0]
+ assert (lw.ndim > 1)
+ if lw.shape[1] > 0:
+ A = lw.reshape(batch_size, lw.shape[1], -1).transpose(1, 2)
+ ptb = PerturbationLpNorm(x_L=linear.x_L, x_U=linear.x_U)
+ lower = lower + ptb.concretize(x=None, A=A, sign=-1).view(lb.shape)
+ upper = upper + ptb.concretize(x=None, A=A, sign=1).view(lb.shape)
+ offset_next = offset + max_dim
+ more = offset_next < tot_dim
+ else:
+ more = False
+
+ if C is None and offset == 0 and not more:
+ node.linear = linear
+
+ if more:
+ if lw is not None and lw.shape[1] > 0:
+ del A
+ del ptb
+ del lw
+ del linear
+ del inp
+ # TODO make it non-recursive
+ lower_next, upper_next = self.forward_general_dynamic(
+ C, node, concretize=True, offset=offset_next)
+ lower = lower + lower_next
+ upper = upper + upper_next
+
+ if C is None:
+ node.lower, node.upper = lower, upper
+
+ return lower, upper
+ else:
+ return linear
+
+
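+# Hedged illustration of the chunking that forward_general_dynamic() performs:
+# the coefficient matrix is consumed at most `max_dim` input columns at a time
+# and partial concretizations are accumulated, which the recursive `offset`
+# handling above implements without materializing the full matrix.
+def _chunked_lower_bound_sketch(lw, lb, x_L, x_U, max_dim):
+    # lw: (out_dim, in_dim); lb: (out_dim,); x_L, x_U: (in_dim,)
+    lower = lb.clone()
+    for off in range(0, lw.shape[1], max_dim):
+        w = lw[:, off:off + max_dim]
+        lower = lower + w.clamp(min=0) @ x_L[off:off + max_dim] \
+            + w.clamp(max=0) @ x_U[off:off + max_dim]
+    return lower
+
+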
+def clean_memory(self, node):
+ """ Remove linear bounds that are no longer needed. """
+ # TODO add an option to retain these bounds
+
+ for inp in node.inputs:
+ if hasattr(inp, 'linear') and inp.linear is not None:
+ clean = True
+ for out in inp.output_name:
+ out_node = self[out]
+ if not (hasattr(out_node, 'linear') and out_node.linear is not None):
+ clean = False
+ if clean:
+ if isinstance(inp.linear, tuple):
+ for item in inp.linear:
+ del item
+ delattr(inp, 'linear')
+
+
+def forward_refinement(self, node):
+ """ Refine forward bounds with backward bound propagation
+ (only refine unstable positions). """
+ unstable_size_before = torch.logical_and(node.lower < 0, node.upper > 0).sum()
+ if unstable_size_before == 0:
+ return
+ unstable_idx, unstable_size = self.get_unstable_locations(
+ node.lower, node.upper, conv=isinstance(node, BoundConv))
+ logger.debug(f'Forward refinement for {node}')
+ batch_size = node.lower.shape[0]
+ ret = self.batched_backward(
+ node, C=None, unstable_idx=unstable_idx, batch_size=batch_size)
+ self.restore_sparse_bounds(
+ node, unstable_idx, unstable_size, node.lower, node.upper,
+ new_lower=ret[0], new_upper=ret[1])
+ unstable_size_after = torch.logical_and(node.lower < 0, node.upper > 0).sum()
+ logger.debug(f' Unstable neurons: {unstable_size_before} -> {unstable_size_after}')
+ # TODO also update linear bounds?
+
+
+def init_forward(self, root, dim_in):
+ if dim_in == 0:
+ raise ValueError("At least one node should have a specified perturbation")
+ prev_dim_in = 0
+ # Assumption: root[0] is the input node which implies batch_size
+ batch_size = root[0].value.shape[0]
+ dynamic = self.bound_opts['dynamic_forward']
+ for i in range(len(root)):
+ if hasattr(root[i], 'perturbation') and root[i].perturbation is not None:
+ shape = root[i].linear.lw.shape
+ if dynamic:
+ if shape[1] != dim_in:
+ raise NotImplementedError('Dynamic forward bound is not supported yet when there are multiple perturbed inputs.')
+ ptb = root[i].perturbation
+ if (type(ptb) != PerturbationLpNorm or ptb.norm < np.inf
+ or ptb.x_L is None or ptb.x_U is None):
+ raise NotImplementedError(
+ 'For dynamic forward bounds, only Linf (box) perturbations are supported, and x_L and x_U must be explicitly provided.')
+ root[i].linear.x_L = (
+ ptb.x_L_sparse.view(batch_size, -1) if ptb.sparse
+ else ptb.x_L.view(batch_size, -1))
+ root[i].linear.x_U = (
+ ptb.x_U_sparse.view(batch_size, -1) if ptb.sparse
+ else ptb.x_U.view(batch_size, -1))
+ else:
+ lw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.lw)
+ lw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.lw
+ if root[i].linear.lw.data_ptr() == root[i].linear.uw.data_ptr():
+ uw = lw
+ else:
+ uw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.uw)
+ uw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.uw
+ root[i].linear.lw = lw
+ root[i].linear.uw = uw
+ if i >= self.num_global_inputs:
+ root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat(
+ *([batch_size] + [1] * root[i].forward_value.ndim))
+ prev_dim_in += shape[1]
+ else:
+ b = fv = root[i].forward_value
+ shape = fv.shape
+ if root[i].from_input:
+ w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device)
+ warnings.warn(f'Creating a LinearBound with zero weights with shape {w.shape}')
+ else:
+ w = None
+ root[i].linear = LinearBound(w, b, w, b, b, b)
+ root[i].lower = root[i].upper = b
+ root[i].interval = (root[i].lower, root[i].upper)
diff --git a/auto_LiRPA/interval_bound.py b/auto_LiRPA/interval_bound.py
new file mode 100644
index 0000000..0fd09c4
--- /dev/null
+++ b/auto_LiRPA/interval_bound.py
@@ -0,0 +1,147 @@
+import torch
+from .bound_ops import *
+
+
+def IBP_general(self, node=None, C=None, delete_bounds_after_use=False):
+
+ def _delete_unused_bounds(node_list):
+ """Delete bounds from input layers after use to save memory. Used when
+ sparse_intermediate_bounds_with_ibp is true."""
+ if delete_bounds_after_use:
+ for n in node_list:
+ del n.interval
+ del n.lower
+ del n.upper
+
+ if self.bound_opts.get('loss_fusion', False):
+ res = self._IBP_loss_fusion(node, C)
+ if res is not None:
+ return res
+
+ if not node.perturbed and hasattr(node, 'forward_value'):
+ node.lower, node.upper = node.interval = (
+ node.forward_value, node.forward_value)
+
+ to_be_deleted_bounds = []
+ if not hasattr(node, 'interval'):
+ for n in node.inputs:
+ if not hasattr(n, 'interval'):
+ # Node n does not have interval bounds; we must compute it.
+ self.IBP_general(
+ n, delete_bounds_after_use=delete_bounds_after_use)
+ to_be_deleted_bounds.append(n)
+ inp = [n_pre.interval for n_pre in node.inputs]
+ if (C is not None and isinstance(node, BoundLinear)
+ and not node.is_input_perturbed(1)):
+ # merge the last BoundLinear node with the specification, available
+ # when weights of this layer are not perturbed
+ ret = node.interval_propagate(*inp, C=C)
+ _delete_unused_bounds(to_be_deleted_bounds)
+ return ret
+ else:
+ node.interval = node.interval_propagate(*inp)
+
+ node.lower, node.upper = node.interval
+ if isinstance(node.lower, torch.Size):
+ node.lower = torch.tensor(node.lower)
+ node.interval = (node.lower, node.upper)
+ if isinstance(node.upper, torch.Size):
+ node.upper = torch.tensor(node.upper)
+ node.interval = (node.lower, node.upper)
+
+ if C is not None:
+ _delete_unused_bounds(to_be_deleted_bounds)
+ return BoundLinear.interval_propagate(None, node.interval, C=C)
+ else:
+ _delete_unused_bounds(to_be_deleted_bounds)
+ return node.interval
+
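+# Hedged reference sketch (not the library code path): IBP through a single
+# linear layer y = W x + b over an elementwise interval [l, u], which is what
+# the interval_propagate() implementations called above compute internally.
+def _ibp_linear_sketch(W, b, l, u):
+    # W: (out_dim, in_dim); b: (out_dim,); l, u: (batch, in_dim)
+    mid, rad = (u + l) / 2, (u - l) / 2
+    center = mid @ W.t() + b
+    radius = rad @ W.abs().t()
+    return center - radius, center + radius
+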
+def _IBP_loss_fusion(self, node, C):
+ """Merge BoundLinear, BoundGatherElements and BoundSub.
+
+ Improvement when loss fusion is used in training.
+ """
+
+ # not using loss fusion
+ if not self.bound_opts.get('loss_fusion', False):
+ return None
+
+ # Currently this function has issues in more complicated networks.
+ if self.bound_opts.get('no_ibp_loss_fusion', False):
+ return None
+
+ if (C is None and isinstance(node, BoundSub)
+ and isinstance(node.inputs[1], BoundGatherElements)
+ and isinstance(node.inputs[0], BoundLinear)):
+ node_gather = node.inputs[1]
+ node_linear = node.inputs[0]
+ node_start = node_linear.inputs[0]
+ w = node_linear.inputs[1].param
+ b = node_linear.inputs[2].param
+ labels = node_gather.inputs[1]
+ if not hasattr(node_start, 'interval'):
+ self.IBP_general(node_start)
+ for n in node_gather.inputs:
+ if not hasattr(n, 'interval'):
+ self.IBP_general(n)
+ if torch.isclose(labels.lower, labels.upper, 1e-8).all():
+ labels = labels.lower
+ batch_size = labels.shape[0]
+ w = w.expand(batch_size, *w.shape)
+ w = w - torch.gather(
+ w, dim=1,
+ index=labels.unsqueeze(-1).repeat(1, w.shape[1], w.shape[2]))
+ b = b.expand(batch_size, *b.shape)
+ b = b - torch.gather(b, dim=1,
+ index=labels.repeat(1, b.shape[1]))
+ lower, upper = node_start.interval
+ lower, upper = lower.unsqueeze(1), upper.unsqueeze(1)
+ node.lower, node.upper = node_linear.interval_propagate(
+ (lower, upper), (w, w), (b.unsqueeze(1), b.unsqueeze(1)))
+ node.interval = node.lower, node.upper = (
+ node.lower.squeeze(1), node.upper.squeeze(1))
+ return node.interval
+
+ return None
+
+
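+# Hedged sketch of the margin construction performed by the fusion above: for
+# each sample, the row of (w, b) selected by the label is subtracted, so the
+# fused layer directly outputs logit margins against the ground-truth class.
+def _margin_weights_sketch(w, b, labels):
+    # w: (num_classes, in_dim); b: (num_classes,); labels: (batch,)
+    w_margin = w.unsqueeze(0) - w[labels].unsqueeze(1)
+    b_margin = b.unsqueeze(0) - b[labels].unsqueeze(1)
+    return w_margin, b_margin
+
+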
+def check_IBP_intermediate(self, node):
+ """ Check if we use IBP bounds to compute intermediate bounds on this node.
+ Basically we check if we can get bounds by only visiting operators in
+ `self.ibp_intermediate`.
+
+ Currently, we assume all eligible operators have exactly one input. """
+ nodes = []
+ while not hasattr(node, 'lower') or not hasattr(node, 'upper'):
+ if type(node) not in self.ibp_intermediate:
+ return False
+ nodes.append(node)
+ node = node.inputs[0]
+ nodes.reverse()
+ for n in nodes:
+ n.interval = self.IBP_general(n)
+ return True
+
+
+def check_IBP_first_linear(self, node):
+ """Here we avoid creating a big C matrix in the first linear layer.
+ Disable this optimization when we have beta for intermediate layer bounds.
+ Disable this optimization when we need the A matrix of the first nonlinear
+ layer, forcibly use CROWN to record A matrix.
+ """
+ # This is the list of all intermediate layers where we need to refine.
+ if self.intermediate_constr is not None:
+ intermediate_beta_enabled_layers = [
+ k for v in self.intermediate_constr.values() for k in v]
+ else:
+ intermediate_beta_enabled_layers = []
+
+ if (node.name not in self.needed_A_dict.keys()
+ and (type(node) == BoundLinear
+ or type(node) == BoundConv
+ and node.name not in intermediate_beta_enabled_layers)):
+ if type(node.inputs[0]) == BoundInput:
+ node.lower, node.upper = self.IBP_general(node)
+ return True
+
+ return False
diff --git a/auto_LiRPA/linear_bound.py b/auto_LiRPA/linear_bound.py
new file mode 100644
index 0000000..ef50028
--- /dev/null
+++ b/auto_LiRPA/linear_bound.py
@@ -0,0 +1,33 @@
+class LinearBound:
+ def __init__(
+ self, lw=None, lb=None, uw=None, ub=None, lower=None, upper=None,
+ from_input=None, x_L=None, x_U=None, offset=0, tot_dim=None):
+ self.lw = lw
+ self.lb = lb
+ self.uw = uw
+ self.ub = ub
+ self.lower = lower
+ self.upper = upper
+ self.from_input = from_input
+ self.x_L = x_L
+ self.x_U = x_U
+ # Offset for input variables. Used for batched forward bound
+ # propagation.
+ self.offset = offset
+ if tot_dim is not None:
+ self.tot_dim = tot_dim
+ elif lw is not None:
+ self.tot_dim = lw.shape[1]
+ else:
+ self.tot_dim = 0
+
+ def is_single_bound(self):
+ """Check whether the linear lower bound and the linear upper bound are
+ the same."""
+ if (self.lw is not None and self.uw is not None
+ and self.lb is not None and self.ub is not None):
+ return (self.lw.data_ptr() == self.uw.data_ptr()
+ and self.lb.data_ptr() == self.ub.data_ptr()
+ and self.x_L is not None and self.x_U is not None)
+ else:
+ return True
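+
+
+# Example sketch (not part of the class API): the LinearBound of a perturbed
+# input node is the identity map over the box [x_L, x_U] (lw = uw = I,
+# lb = ub = 0), matching what init_forward() in forward_bound.py creates.
+def _identity_linear_bound(x_L, x_U):
+    import torch
+    batch_size, dim = x_L.shape
+    eye = torch.eye(dim, device=x_L.device).unsqueeze(0).repeat(batch_size, 1, 1)
+    zero = torch.zeros(batch_size, dim, device=x_L.device)
+    return LinearBound(eye, zero, eye, zero, x_L=x_L, x_U=x_U)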
diff --git a/auto_LiRPA/operators/__init__.py b/auto_LiRPA/operators/__init__.py
index 9020901..56cb4f3 100644
--- a/auto_LiRPA/operators/__init__.py
+++ b/auto_LiRPA/operators/__init__.py
@@ -1,7 +1,9 @@
from .base import *
from .linear import *
from .convolution import *
+from .pooling import *
from .activation import *
+from .nonlinear import *
from .bivariate import *
from .normalization import *
from .shape import *
@@ -11,5 +13,7 @@
from .constant import *
from .leaf import *
from .logical import *
-from .dropout import *
-from .dtype import *
\ No newline at end of file
+from .dropout import *
+from .dtype import *
+from .cut_ops import *
+from .solver_utils import grb
\ No newline at end of file
diff --git a/auto_LiRPA/operators/activation.py b/auto_LiRPA/operators/activation.py
index 7aa2dca..1383100 100644
--- a/auto_LiRPA/operators/activation.py
+++ b/auto_LiRPA/operators/activation.py
@@ -1,43 +1,74 @@
""" Activation operators or other unary nonlinear operators"""
+from typing import Optional, Tuple
+import torch
+from torch import Tensor
from .base import *
+from .clampmult import multiply_by_A_signs
+from .solver_utils import grb
+from ..utils import unravel_index, logger, prod
+
+
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+
class BoundActivation(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.requires_input_bounds = [0]
self.relaxed = False
def _init_masks(self, x):
- self.mask_pos = torch.ge(x.lower, 0).to(torch.float)
- self.mask_neg = torch.le(x.upper, 0).to(torch.float)
- self.mask_both = 1 - self.mask_pos - self.mask_neg
+ self.mask_pos = x.lower >= 0
+ self.mask_neg = x.upper <= 0
+ self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg))
- def _init_linear(self, x, dim_opt=None):
+ def init_linear_relaxation(self, x, dim_opt=None):
self._init_masks(x)
self.lw = torch.zeros_like(x.lower)
self.lb = self.lw.clone()
self.uw = self.lw.clone()
self.ub = self.lw.clone()
- def _add_linear(self, mask, type, k, x0, y0):
- if mask is None:
- mask = 1
+ def add_linear_relaxation(self, mask, type, k, x0, y0):
if type == 'lower':
w_out, b_out = self.lw, self.lb
else:
w_out, b_out = self.uw, self.ub
- w_out += mask * k
- b_out += mask * (-x0 * k + y0)
+
+ if mask is None:
+ if isinstance(k, Tensor) and k.ndim > 0:
+ w_out[:] = k
+ else:
+ w_out.fill_(k)
+ else:
+ if isinstance(k, Tensor):
+ w_out[..., mask] = k[..., mask].to(w_out)
+ else:
+ w_out[..., mask] = k
+
+ if (not isinstance(x0, Tensor) and x0 == 0
+ and not isinstance(y0, Tensor) and y0 == 0):
+ pass
+ else:
+ b = -x0 * k + y0
+ if mask is None:
+ if b.ndim > 0:
+ b_out[:] = b
+ else:
+ b_out.fill_(b)
+ else:
+ b_out[..., mask] = b[..., mask]
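+
+    # Example: for an unstable ReLU neuron with pre-activation bounds [l, u],
+    # calling add_linear_relaxation(mask_both, 'upper', k=u/(u-l), x0=l, y0=0)
+    # stores the chord through (l, 0) and (u, u): w = u/(u-l) and
+    # b = -l*u/(u-l) at the masked positions.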
def bound_relax(self, x):
return not_implemented_op(self, 'bound_relax')
-
+
def interval_propagate(self, *v):
- return self.default_interval_propagate(*v)
+ return self.default_interval_propagate(*v)
def bound_backward(self, last_lA, last_uA, x):
if not self.relaxed:
- self._init_linear(x)
+ self.init_linear_relaxation(x)
self.bound_relax(x)
def _bound_oneside(last_A, sign=-1):
@@ -45,23 +76,25 @@ def _bound_oneside(last_A, sign=-1):
return None, 0
if sign == -1:
w_pos, b_pos, w_neg, b_neg = (
- self.lw.unsqueeze(0), self.lb.unsqueeze(0),
+ self.lw.unsqueeze(0), self.lb.unsqueeze(0),
self.uw.unsqueeze(0), self.ub.unsqueeze(0))
else:
w_pos, b_pos, w_neg, b_neg = (
- self.uw.unsqueeze(0), self.ub.unsqueeze(0),
+ self.uw.unsqueeze(0), self.ub.unsqueeze(0),
self.lw.unsqueeze(0), self.lb.unsqueeze(0))
+ w_pos = maybe_unfold_patches(w_pos, last_A)
+ w_neg = maybe_unfold_patches(w_neg, last_A)
+ b_pos = maybe_unfold_patches(b_pos, last_A)
+ b_neg = maybe_unfold_patches(b_neg, last_A)
if self.batch_dim == 0:
- _A = last_A.clamp(min=0) * w_pos + last_A.clamp(max=0) * w_neg
- _bias = last_A.clamp(min=0) * b_pos + last_A.clamp(max=0) * b_neg
- if _bias.ndim > 2:
- _bias = torch.sum(_bias, dim=list(range(2, _bias.ndim)))
+ _A, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg)
elif self.batch_dim == -1:
+ # FIXME: why is this different from the case above?
mask = torch.gt(last_A, 0.).to(torch.float)
_A = last_A * (mask * w_pos.unsqueeze(1) +
- (1 - mask) * w_neg.unsqueeze(1))
+ (1 - mask) * w_neg.unsqueeze(1))
_bias = last_A * (mask * b_pos.unsqueeze(1) +
- (1 - mask) * b_neg.unsqueeze(1))
+ (1 - mask) * b_neg.unsqueeze(1))
if _bias.ndim > 2:
_bias = torch.sum(_bias, dim=list(range(2, _bias.ndim)))
else:
@@ -74,29 +107,39 @@ def _bound_oneside(last_A, sign=-1):
return [(lA, uA)], lbias, ubias
+ @staticmethod
+ @torch.jit.script
+ def bound_forward_w(
+ relax_lw: Tensor, relax_uw: Tensor, x_lw: Tensor, x_uw: Tensor, dim: int):
+ lw = (relax_lw.unsqueeze(dim).clamp(min=0) * x_lw +
+ relax_lw.unsqueeze(dim).clamp(max=0) * x_uw)
+ uw = (relax_uw.unsqueeze(dim).clamp(max=0) * x_lw +
+ relax_uw.unsqueeze(dim).clamp(min=0) * x_uw)
+ return lw, uw
+
+ @staticmethod
+ @torch.jit.script
+ def bound_forward_b(
+ relax_lw: Tensor, relax_uw: Tensor, relax_lb: Tensor,
+ relax_ub: Tensor, x_lb: Tensor, x_ub: Tensor):
+ lb = relax_lw.clamp(min=0) * x_lb + relax_lw.clamp(max=0) * x_ub + relax_lb
+ ub = relax_uw.clamp(max=0) * x_lb + relax_uw.clamp(min=0) * x_ub + relax_ub
+ return lb, ub
+
def bound_forward(self, dim_in, x):
if not self.relaxed:
- self._init_linear(x)
+ self.init_linear_relaxation(x)
self.bound_relax(x)
- if self.lw.ndim > 0:
- if x.lw is not None:
- lw = self.lw.unsqueeze(1).clamp(min=0) * x.lw + \
- self.lw.unsqueeze(1).clamp(max=0) * x.uw
- uw = self.uw.unsqueeze(1).clamp(max=0) * x.lw + \
- self.uw.unsqueeze(1).clamp(min=0) * x.uw
- else:
- lw = uw = None
+ assert (x.lw is None) == (x.uw is None)
+
+ dim = 1 if self.lw.ndim > 0 else 0
+
+ if x.lw is not None:
+ lw, uw = BoundActivation.bound_forward_w(self.lw, self.uw, x.lw, x.uw, dim)
else:
- if x.lw is not None:
- lw = self.lw.unsqueeze(0).clamp(min=0) * x.lw + \
- self.lw.unsqueeze(0).clamp(max=0) * x.uw
- uw = self.uw.unsqueeze(0).clamp(min=0) * x.lw + \
- self.uw.unsqueeze(0).clamp(max=0) * x.uw
- else:
- lw = uw = None
- lb = self.lw.clamp(min=0) * x.lb + self.lw.clamp(max=0) * x.ub + self.lb
- ub = self.uw.clamp(max=0) * x.lb + self.uw.clamp(min=0) * x.ub + self.ub
+ lw = uw = None
+ lb, ub = BoundActivation.bound_forward_b(self.lw, self.uw, self.lb, self.ub, x.lb, x.ub)
return LinearBound(lw, lb, uw, ub)
@@ -106,47 +149,96 @@ def interval_propagate(self, *v):
class BoundOptimizableActivation(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- # Two stages: `init` (initializing parameters) and `opt` (optimizing parameters).
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ # Stages:
+ # * `init`: initializing parameters
+ # * `opt`: optimizing parameters
+ # * `reuse`: not optimizing parameters but reuse saved values
# If `None`, it means activation optimization is currently not used.
self.opt_stage = None
+ self.alpha = OrderedDict()
+ # Save patch sizes during bound_backward() for each output_node.
+ self.patch_size = {}
+ # Location of batch dimension in self.alpha. Must be set by children.
+ self.alpha_batch_dim = None
+ # A torch.bool mask of shape (batch_size,) that selects which samples'
+ # alpha and beta are updated.
+ # If None, update all samples;
+ # otherwise, only update the samples whose entry is True.
- """ Enter the stage for initializing bound optimization. Optimized bounds are not used in
- this stage. """
def opt_init(self):
+ """Enter the stage for initializing bound optimization. Optimized bounds
+ are not used in this stage."""
self.opt_stage = 'init'
- """ Start optimizing bounds """
def opt_start(self):
+ """Start optimizing bounds."""
self.opt_stage = 'opt'
- """ start_nodes: a list of starting nodes [(node, size)] during
- CROWN backward bound propagation"""
+ def opt_reuse(self):
+ """ Reuse optimizing bounds """
+ self.opt_stage = 'reuse'
+
+ def opt_no_reuse(self):
+ """ Finish reusing optimized bounds """
+ if self.opt_stage == 'reuse':
+ self.opt_stage = None
+
+ def opt_end(self):
+ """ End optimizing bounds """
+ self.opt_stage = None
+
def init_opt_parameters(self, start_nodes):
+ """ start_nodes: a list of starting nodes [(node, size)] during
+ CROWN backward bound propagation"""
raise NotImplementedError
- def _init_linear(self, x, dim_opt=None):
+ def clip_alpha_(self):
+ pass
+
+ def init_linear_relaxation(self, x, dim_opt=None):
self._init_masks(x)
# The first dimension of size 2 is used for lA and uA respectively,
# when computing intermediate bounds.
- if self.opt_stage == 'opt' and dim_opt:
- self.lw = torch.zeros(2, dim_opt, *x.lower.shape).to(x.lower)
+ if self.opt_stage in ['opt', 'reuse'] and dim_opt is not None:
+ # For optimized bounds, we have independent lw for each output dimension for bound optimization.
+ # If the output layer is a fully connected layer, len(dim_opt) = 1.
+ # If the output layer is a conv layer, len(dim_opt) = 3 but we only use the out_c dimension to create slopes/bias.
+ # Variables are shared among out_h, out_w dimensions so far.
+ dim = dim_opt if isinstance(dim_opt, int) else dim_opt[0]
+ self.lw = torch.zeros(2, dim, *x.lower.shape).to(x.lower)
else:
+ # Without optimized bounds, lw, lb (slope, bias), etc. only depend on intermediate layer bounds,
+ # and are shared among different output dimensions.
self.lw = torch.zeros_like(x.lower)
self.lb = self.lw.clone()
self.uw = self.lw.clone()
- self.ub = self.lw.clone()
+ self.ub = self.lw.clone()
def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None):
- self._start = start_node
-
- if self.opt_stage != 'opt':
- return super().bound_backward(last_lA, last_uA, x)
- assert self.batch_dim == 0
+ self._start = start_node.name
+
+ if self.opt_stage not in ['opt', 'reuse']:
+ last_A = last_lA if last_lA is not None else last_uA
+ # Returned [(lA, uA)], lbias, ubias
+ As, lbias, ubias = super().bound_backward(last_lA, last_uA, x)
+ if isinstance(last_A, Patches):
+ A_prod = As[0][1].patches if As[0][0] is None else As[0][1].patches
+ # FIXME: Unify this function with BoundReLU
+ # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.
+ if start_node is not None:
+ if last_A.unstable_idx is not None:
+ # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
+ self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]
+ else:
+ # Regular patches.
+ self.patch_size[start_node.name] = A_prod.size()
+ return As, lbias, ubias
+ assert self.batch_dim == 0
if not self.relaxed:
- self._init_linear(x, dim_opt=start_shape)
+ self.init_linear_relaxation(x, dim_opt=start_shape)
self.bound_relax(x)
def _bound_oneside(last_A, sign=-1):
@@ -156,11 +248,12 @@ def _bound_oneside(last_A, sign=-1):
w_pos, b_pos, w_neg, b_neg = self.lw[0], self.lb[0], self.uw[0], self.ub[0]
else:
w_pos, b_pos, w_neg, b_neg = self.uw[1], self.ub[1], self.lw[1], self.lb[1]
- _A = last_A.clamp(min=0) * w_pos + last_A.clamp(max=0) * w_neg
- _bias = last_A.clamp(min=0) * b_pos + last_A.clamp(max=0) * b_neg
- if _bias.ndim > 2:
- _bias = torch.sum(_bias, list(range(2, _bias.ndim)))
- return _A, _bias
+ w_pos = maybe_unfold_patches(w_pos, last_A)
+ w_neg = maybe_unfold_patches(w_neg, last_A)
+ b_pos = maybe_unfold_patches(b_pos, last_A)
+ b_neg = maybe_unfold_patches(b_neg, last_A)
+ A_prod, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg)
+ return A_prod, _bias
lA, lbias = _bound_oneside(last_lA, sign=-1)
uA, ubias = _bound_oneside(last_uA, sign=+1)
@@ -169,692 +262,722 @@ def _bound_oneside(last_A, sign=-1):
def _no_bound_parameters(self):
raise AttributeError('Bound parameters have not been initialized.'
- 'Please call `compute_bounds` with `method=CROWN-optimized`'
- ' at least once.')
+ ' Please call `compute_bounds` with `method=CROWN-optimized`'
+ ' at least once.')
-class BoundLeakyRelu(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
- self.options = options.get('relu')
- self.alpha = attr['alpha']
+ def dump_optimized_params(self):
+ raise NotImplementedError
- @Bound.save_io_shape
- def forward(self, x):
- return F.leaky_relu(x, negative_slope=self.alpha)
+ def restore_optimized_params(self):
+ raise NotImplementedError
- def bound_backward(self, last_lA, last_uA, x=None, start_node=None, start_shape=None):
- if x is not None:
- lb_r = x.lower.clamp(max=0)
- ub_r = x.upper.clamp(min=0)
- else:
- lb_r = self.lower.clamp(max=0)
- ub_r = self.upper.clamp(min=0)
- ub_r = torch.max(ub_r, lb_r + 1e-8)
- upper_d = (ub_r - self.alpha * lb_r) / (ub_r - lb_r)
- upper_b = - lb_r * upper_d + self.alpha * lb_r
+ def set_alpha_beta_update_mask(self, mask):
+ self.alpha_beta_update_mask = mask
- if self.options == "same-slope":
- # the same slope for upper and lower
- lower_d = upper_d
- elif self.options == "zero-lb":
- # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN
- lower_d = (upper_d >= 1.0).float() + (upper_d < 1.0).float() * self.alpha
- elif self.options == "one-lb":
- # Always use slope 1 as lower bound
- lower_d = (upper_d > 0.0).float() + (upper_d <= 0.0).float() * self.alpha
- else:
- lower_d = (upper_d > 0.5).float() + (upper_d <= 0.5).float() * self.alpha
-
- upper_d = upper_d.unsqueeze(0)
- lower_d = lower_d.unsqueeze(0)
- # Choose upper or lower bounds based on the sign of last_A
- uA = lA = None
- ubias = lbias = 0
- if last_uA is not None:
- neg_uA = last_uA.clamp(max=0)
- pos_uA = last_uA.clamp(min=0)
- uA = upper_d * pos_uA + lower_d * neg_uA
- ubias = self.get_bias(pos_uA, upper_b)
- if last_lA is not None:
- neg_lA = last_lA.clamp(max=0)
- pos_lA = last_lA.clamp(min=0)
- lA = upper_d * neg_lA + lower_d * pos_lA
- lbias = self.get_bias(neg_lA, upper_b)
- return [(lA, uA)], lbias, ubias
+ def clean_alpha_beta_update_mask(self):
+ self.alpha_beta_update_mask = None
class BoundRelu(BoundOptimizableActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.options = options
- self.relu_options = options.get('relu', 'adaptive')
+ self.relu_options = options.get('relu', 'adaptive') # FIXME: use better names.
+ self.use_sparse_spec_alpha = options.get('sparse_spec_alpha', False)
+ self.use_sparse_features_alpha = options.get('sparse_features_alpha', False)
self.beta = self.beta_mask = self.masked_beta = self.sparse_beta = None
self.split_beta_used = False
self.history_beta_used = False
self.flattened_nodes = None
# Save patches size for each output node.
self.patch_size = {}
+ self.cut_used = False
+ self.cut_module = None
+ # Alpha dimension is (2, output_shape, batch, *shape) for ReLU.
+ self.alpha_batch_dim = 2
def init_opt_parameters(self, start_nodes):
- self.alpha = OrderedDict()
ref = self.inputs[0].lower # a reference variable for getting the shape
- for ns, size_s in start_nodes:
- self.alpha[ns] = torch.empty([2, size_s, ref.size(0), *self.shape],
- dtype=torch.float, device=ref.device, requires_grad=True)
- for k, v in self.alpha.items():
- v.data.copy_(self.lower_d.data) # Initial from adaptive lower bounds.
+ batch_size = ref.size(0)
+ self.alpha = OrderedDict()
+ self.alpha_lookup_idx = OrderedDict() # For alpha with a sparse spec dimension.
+ self.alpha_indices = None # indices of non-zero alphas.
+ verbosity = self.options.get('verbosity', 0)
+
+ # Alpha can be sparse in both spec dimension, and the C*H*W dimension.
+ # We first deal with the sparse-feature alpha, which is sparse in the
+ # C*H*W dimension of this layer.
+ minimum_sparsity = self.options.get('minimum_sparsity', 0.9)
+ if (hasattr(self.inputs[0], 'lower') and hasattr(self.inputs[0], 'upper')
+ and self.use_sparse_features_alpha):
+ # Pre-activation bounds available, we will store the alpha for unstable neurons only.
+ # Since each element in a batch can have different unstable neurons,
+ # for simplicity we find a super-set using any(dim=0).
+ # This can be non-ideal if the x in a batch are very different.
+ self.alpha_indices = torch.logical_and(
+ self.inputs[0].lower < 0, self.inputs[0].upper > 0).any(dim=0).nonzero(as_tuple=True)
+ total_neuron_size = self.inputs[0].lower.numel() // batch_size
+ if self.alpha_indices[0].size(0) <= minimum_sparsity * total_neuron_size:
+ # Shape is the number of unstable neurons in this layer.
+ alpha_shape = [self.alpha_indices[0].size(0)]
+ # Skip the batch, spec dimension, and find the lower slopes for all unstable neurons.
+ if len(self.alpha_indices) == 1:
+ # This layer is after a linear layer.
+ alpha_init = self.lower_d[:, :, self.alpha_indices[0]]
+ elif len(self.alpha_indices) == 3:
+ # This layer is after a conv layer.
+ alpha_init = self.lower_d[
+ :, :, self.alpha_indices[0], self.alpha_indices[1],
+ self.alpha_indices[2]]
+ else:
+ raise ValueError
+ if verbosity > 0:
+ print(f'layer {self.name} using sparse-features alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})')
+ else:
+ alpha_shape = self.shape # Full alpha.
+ alpha_init = self.lower_d
+ if verbosity > 0:
+ print(f'layer {self.name} using full alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})')
+ self.alpha_indices = None # Use full alpha.
+ else:
+ alpha_shape = self.shape # Full alpha.
+ alpha_init = self.lower_d
+ # Now we start to create alphas for all start nodes.
+ # When sparse-spec feature is enabled, alpha is created for only
+ # unstable neurons in start node.
+ for ns, output_shape, unstable_idx in start_nodes:
+ if isinstance(output_shape, (list, tuple)):
+ if len(output_shape) > 1:
+ size_s = prod(output_shape) # Conv layers.
+ else:
+ size_s = output_shape[0]
+ else:
+ size_s = output_shape
+ # unstable_idx may be a tensor (dense layer or conv layer
+ # with shared alpha), or tuple of 3-d tensors (conv layer with
+ # non-sharing alpha).
+ sparsity = float('inf') if unstable_idx is None else unstable_idx.size(0) if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].size(0)
+ if sparsity <= minimum_sparsity * size_s and self.use_sparse_spec_alpha:
+ if verbosity > 0:
+ print(f'layer {self.name} start_node {ns} using sparse-spec alpha with unstable size {sparsity} total_size {size_s} output_shape {output_shape}')
+ # For fully connected layer, or conv layer with shared alpha per channel.
+ # shape is (2, sparse_spec, batch, this_layer_shape)
+ # We create a sparse specification dimension: the spec dimension of alpha only includes slopes for unstable neurons in start_node.
+ self.alpha[ns] = torch.empty([2, sparsity + 1, batch_size, *alpha_shape],
+ dtype=torch.float, device=ref.device, requires_grad=True)
+ self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, sparse_spec) dimensions.
+ # unstable_idx is a list of used neurons (or channels for BoundConv) for the start_node.
+ assert unstable_idx.ndim == 1 if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].ndim == 1
+ # We only need the alpha for the unstable neurons in start_node.
+ indices = torch.arange(1, sparsity + 1, device=alpha_init.device, dtype=torch.long)
+ if isinstance(output_shape, int) or len(output_shape) == 1:
+ # Fully connected layers, or conv layer in patches mode with partially shared alpha (pixels in the same channel use the same alpha).
+ self.alpha_lookup_idx[ns] = torch.zeros(size_s, dtype=torch.long, device=alpha_init.device)
+ # This lookup table maps the unstable_idx to the actual alpha location in self.alpha[ns].
+ # Note that self.alpha[ns][:,0] is reserved for any unstable neurons that are not found in the lookup table. This usually should not
+ # happen, unless reference bounds are not properly set.
+ self.alpha_lookup_idx[ns].data[unstable_idx] = indices
+ else:
+ # conv layer in matrix mode, or in patches mode but with non-shared alpha. The lookup table is 3-d.
+ assert len(output_shape) == 3
+ self.alpha_lookup_idx[ns] = torch.zeros(output_shape, dtype=torch.long, device=alpha_init.device)
+ if isinstance(unstable_idx, torch.Tensor):
+ # Convert the unstable index from flattened 1-d to 3-d (matrix mode).
+ unstable_idx_3d = unravel_index(unstable_idx, output_shape)
+ else:
+ # Patches mode with non-shared alpha, unstable_idx is already 3d.
+ unstable_idx_3d = unstable_idx
+ # Build look-up table.
+ self.alpha_lookup_idx[ns].data[unstable_idx_3d[0], unstable_idx_3d[1], unstable_idx_3d[2]] = indices
+ else:
+ if verbosity > 0:
+ print(f'layer {self.name} start_node {ns} using full alpha with unstable size {sparsity if unstable_idx is not None else None} total_size {size_s} output_shape {output_shape}')
+ # alpha shape is (2, spec, batch, this_layer_shape). "this_layer_shape" may still be sparse.
+ self.alpha[ns] = torch.empty([2, size_s, batch_size, *alpha_shape],
+ dtype=torch.float, device=ref.device, requires_grad=True)
+ self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, spec) dimensions
+ # alpha_lookup_idx can be used for checking if sparse alpha is used or not.
+ self.alpha_lookup_idx[ns] = None
+
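+    # Illustration of the sparse-spec lookup built in init_opt_parameters above
+    # (hypothetical numbers): with size_s = 10 and unstable_idx = tensor([3, 7]),
+    # indices = [1, 2] and alpha_lookup_idx[ns] = [0, 0, 0, 1, 0, 0, 0, 2, 0, 0];
+    # slot 0 of self.alpha[ns] is a catch-all for neurons missing from the table.
+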
+ def clip_alpha_(self):
+ for v in self.alpha.values():
+ v.data = torch.clamp(v.data, 0., 1.)
- @Bound.save_io_shape
def forward(self, x):
self.shape = x.shape[1:]
if self.flattened_nodes is None:
self.flattened_nodes = x[0].reshape(-1).shape[0]
return F.relu(x)
- # Linear relaxation for nonlinear functions
- # Used for forward mode bound propagation
- def bound_relax(self, x):
- # FIXME maybe avoid using `mask` which looks inefficient
- # m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)
- self._add_linear(mask=self.mask_neg, type='lower',
- k=torch.zeros_like(x.lower), x0=0, y0=0)
- self._add_linear(mask=self.mask_neg, type='upper',
- k=torch.zeros_like(x.lower), x0=0, y0=0)
- self._add_linear(mask=self.mask_pos, type='lower',
- k=torch.ones_like(x.lower), x0=0, y0=0)
- self._add_linear(mask=self.mask_pos, type='upper',
- k=torch.ones_like(x.lower), x0=0, y0=0)
- upper = torch.max(x.upper, x.lower + 1e-8)
- delta = 1e-8
- r = (x.upper - x.lower).clamp(min=delta)
- upper_k = x.upper / r + delta / r
- self._add_linear(mask=self.mask_both, type='upper',
- k=upper_k, x0=x.lower, y0=0)
- if self.relu_options == "same-slope":
+ def _forward_relaxation(self, x):
+ self._init_masks(x)
+ self.mask_pos = self.mask_pos.to(x.lower)
+ self.mask_both = self.mask_both.to(x.lower)
+
+ upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper)
+ self.uw = self.mask_pos + self.mask_both * upper_k
+ self.ub = self.mask_both * upper_b
+
+ if self.opt_stage in ['opt', 'reuse']:
+ # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape].
+ # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape]
+ # and we do not need its first two dimensions.
+ lower_k = alpha = self.alpha['_forward'][0, 0]
+ elif self.relu_options == "same-slope":
lower_k = upper_k
elif self.relu_options == "zero-lb":
lower_k = torch.zeros_like(upper_k)
elif self.relu_options == "one-lb":
lower_k = torch.ones_like(upper_k)
- elif self.opt_stage == 'opt':
- # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape].
- # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape]
- # and we do not need its first two dimensions.
- lower_k = alpha = self.alpha['_forward'][0, 0]
else:
# adaptive
lower_k = torch.gt(torch.abs(x.upper), torch.abs(x.lower)).to(torch.float)
# NOTE #FIXME Saved for initialization bounds for optimization.
# In the backward mode, same-slope bounds are used.
# But here it is using adaptive bounds which seem to be better
- # for nn4sys benchmark with loose input bounds. Need confirmation
+ # for nn4sys benchmark with loose input bounds. Need confirmation
# for other cases.
- self.d = lower_k.detach() # saved for initializing optimized bounds
- self._add_linear(mask=self.mask_both, type='lower',
- k=lower_k, x0=0., y0=0.)
+ self.lower_d = lower_k.detach() # saved for initializing optimized bounds
- def bound_backward(self, last_lA, last_uA, x=None, start_node=None, beta_for_intermediate_layers=False, unstable_idx=None):
- if x is not None:
- # # only set lower and upper bound here when using neuron set version, ie, not ob_update_by_layer
- # if self.beta is not None and not self.options.get('optimize_bound_args', {}).get('ob_update_by_layer', False):
- # if self.beta_mask.abs().sum() != 0:
- # # set bound neuron-wise according to beta_mask
- # x.lower = x.lower * (self.beta_mask != 1).to(torch.float32)
- # x.upper = x.upper * (self.beta_mask != -1).to(torch.float32)
+ self.lw = self.mask_both * lower_k + self.mask_pos
- lb_r = x.lower.clamp(max=0)
- ub_r = x.upper.clamp(min=0)
- else:
- lb_r = self.lower.clamp(max=0)
- ub_r = self.upper.clamp(min=0)
+ def bound_dynamic_forward(self, x, max_dim=None, offset=0):
+ self._init_masks(x)
+ self.mask_pos = self.mask_pos.to(x.lower)
+ self.mask_both = self.mask_both.to(x.lower)
+
+ upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper)
+ w_new = (self.mask_pos.unsqueeze(1) * x.lw
+ + self.mask_both.unsqueeze(1) * upper_k.unsqueeze(1) * x.lw)
+ upper_b = self.mask_both * upper_b / 2
+ b_new = (self.mask_pos * x.lb
+ + self.mask_both * upper_k * x.lb + upper_b)
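+        # The relaxation gap upper_b of each unstable neuron is split symmetrically: half is
+        # folded into the bias b_new above, and half becomes the coefficient of a fresh input
+        # variable ranging over [-1, 1] (scattered into w_unstable below), so that
+        # k*x + b + (upper_b/2)*(1 + eps) with eps in [-1, 1] covers the whole relaxed range.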
+
+ # Create new variables for unstable ReLU
+ batch_size = w_new.shape[0]
+ device = w_new.device
+ unstable = self.mask_both.view(batch_size, -1)
+ tot_unstable = int(unstable.sum(dim=-1).max())
+ tot_dim = x.tot_dim + tot_unstable
+ # logger.debug(f'Unstable: {tot_unstable}')
+
+ if offset + w_new.shape[1] < x.tot_dim:
+ return LinearBound(
+ w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=tot_dim)
+
+ index = torch.cumsum(unstable, dim=-1).to(torch.int64)
+ index = (index - (offset + w_new.shape[1] - x.tot_dim)).clamp(min=0)
+ num_new_dim = int(index.max())
+ num_new_dim_actual = min(num_new_dim, max_dim - w_new.shape[1])
+ index = index.clamp(max=num_new_dim_actual+1)
+ w_unstable = torch.zeros(batch_size, num_new_dim_actual + 2, unstable.size(-1), device=device)
+ x_L_unstable = -torch.ones(batch_size, num_new_dim_actual, device=device)
+ x_U_unstable = torch.ones(batch_size, num_new_dim_actual, device=device)
+ w_unstable.scatter_(dim=1, index=index.unsqueeze(1), src=upper_b.view(batch_size, 1, -1), reduce='add')
+ w_unstable = w_unstable[:, 1:-1].view(batch_size, num_new_dim_actual, *w_new.shape[2:])
+
+ w_new = torch.cat([w_new, w_unstable], dim=1)
+ x_L_new = torch.cat([x.x_L, x_L_unstable], dim=-1)
+ x_U_new = torch.cat([x.x_U, x_U_unstable], dim=-1)
+
+ return LinearBound(
+ w_new, b_new, w_new, b_new, x_L=x_L_new, x_U=x_U_new, tot_dim=tot_dim)
- self.I = ((lb_r != 0) * (ub_r != 0)).detach() # unstable neurons
- # print('unstable neurons:', self.I.sum())
- if hasattr(x, 'interval') and Interval.use_relative_bounds(x.interval):
- diff_x = x.interval.upper_offset - x.interval.lower_offset
- upper_d = (self.interval.upper_offset - self.interval.lower_offset) / diff_x.clamp(min=epsilon)
- mask_tiny_diff = (diff_x <= epsilon).float()
- upper_d = mask_tiny_diff * F.relu(x.upper) + (1 - mask_tiny_diff) * upper_d
+ def bound_forward(self, dim_in, x):
+ self._forward_relaxation(x)
+
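+        # The lower relaxation has zero intercept and element-wise slopes lw in [0, 1], so the
+        # concrete lower bound is simply lw * x.lb; the upper relaxation needs its intercept ub.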
+ lb = self.lw * x.lb
+ ub = self.uw * x.ub + self.ub
+
+ if x.lw is not None:
+ lw = self.lw.unsqueeze(1) * x.lw
+ else:
+ lw = None
+ if x.uw is not None:
+ uw = self.uw.unsqueeze(1) * x.uw
else:
- ub_r = torch.max(ub_r, lb_r + 1e-8)
- upper_d = ub_r / (ub_r - lb_r)
+ uw = None
+
+        if lw is not None and not lw.requires_grad:
+ del self.mask_both, self.mask_pos
+ del self.lw, self.uw, self.ub
+
+ return LinearBound(lw, lb, uw, ub)
+
+ @staticmethod
+ @torch.jit.script
+ def _relu_upper_bound(lb, ub):
+ """Upper bound slope and intercept according to CROWN relaxation."""
+        # TODO: pre-compile all JIT functions before running.
+ lb_r = lb.clamp(max=0)
+ ub_r = ub.clamp(min=0)
+ ub_r = torch.max(ub_r, lb_r + 1e-8)
+ upper_d = ub_r / (ub_r - lb_r)
upper_b = - lb_r * upper_d
+ return upper_d, upper_b
+
+ @staticmethod
+ def _relu_mask_alpha(lower, upper, lb_lower_d : Optional[Tensor], ub_lower_d : Optional[Tensor]) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]:
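+        # Stable neurons override the optimized slope: neurons with lower >= 0 get slope 1
+        # (ReLU is the identity there) and neurons with upper <= 0 get slope 0 (ReLU is
+        # constant zero); only the remaining unstable neurons keep their clamped alpha.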
+ lower_mask = (lower >= 0).requires_grad_(False).to(lower.dtype)
+ upper_mask = (upper <= 0).requires_grad_(False)
+ zero_coeffs = upper_mask.all()
+ no_mask = (1. - lower_mask) * (1. - upper_mask.to(upper.dtype))
+ if lb_lower_d is not None:
+ lb_lower_d = torch.clamp(lb_lower_d, min=0., max=1.) * no_mask + lower_mask
+ if ub_lower_d is not None:
+ ub_lower_d = torch.clamp(ub_lower_d, min=0., max=1.) * no_mask + lower_mask
+ return lb_lower_d, ub_lower_d, zero_coeffs
+
+ def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx):
+ if x is not None:
+ lower = x.lower
+ upper = x.upper
+ else:
+ lower = self.lower
+ upper = self.upper
+
+ # Upper bound slope and intercept according to CROWN relaxation.
+ upper_d, upper_b = self._relu_upper_bound(lower, upper)
flag_expand = False
ub_lower_d = lb_lower_d = None
- if self.relu_options == "same-slope":
- # the same slope for upper and lower
- lower_d = upper_d
- elif self.relu_options == "zero-lb":
- # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN
- lower_d = (upper_d >= 1.0).float()
- elif self.relu_options == "one-lb":
- # Always use slope 1 as lower bound
- lower_d = (upper_d > 0.0).float()
- elif self.relu_options == "reversed-adaptive":
- lower_d = (upper_d < 0.5).float()
- elif self.opt_stage == 'opt':
+ lower_b = None # ReLU does not have lower bound intercept (=0).
+ alpha_lookup_idx = None # For sparse-spec alpha.
+ if self.opt_stage in ['opt', 'reuse']:
# Alpha-CROWN.
lower_d = None
# Each alpha has shape (2, output_shape, batch_size, *relu_node_shape].
# If slope is shared, output_shape will be 1.
+ # The *relu_node_shape might be sparse (sparse-feature alpha), where the non-zero values are indicated by self.alpha_indices.
+ # The out_shape might be sparse (sparse-spec alpha), where the non-zero values are indexed by self.alpha_lookup_idx.
if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1:
+ # print(f'relu layer {self.name}, start_node {start_node}, unstable_idx {type(unstable_idx)} alpha idx {self.alpha_lookup_idx[start_node.name].size()}')
+ alpha_lookup_idx = self.alpha_lookup_idx[start_node.name]
if isinstance(unstable_idx, tuple):
- if isinstance(last_lA, torch.Tensor) or isinstance(last_uA, torch.Tensor):
- # Patches mode converted to matrix. Need to select accross the spec dimension.
+ # Start node is a conv node.
+ selected_alpha = self.alpha[start_node.name]
+ if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor):
+ # Start node is a conv node but we received tensors as A matrices.
+                    # Patches mode converted to matrix, or matrix mode used. Need to select across the spec dimension.
# For this node, since it is in matrix mode, the spec dimension is out_c * out_h * out_w
- selected_alpha = self.alpha[start_node.name]
- # Reshape the spec dimension to c*h*w so we can select based on unstable index.
- selected_alpha = selected_alpha.view(-1, *start_node.output_shape[1:], *selected_alpha.shape[2:])
- selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]]
+ # Shape is [2, spec, batch, *this_layer_shape]
+ if alpha_lookup_idx is None:
+ # Reshape the spec dimension to c*h*w so we can select used alphas based on unstable index.
+ # Shape becomes [2, out_c, out_h, out_w, batch, *this_layer_shape]
+ selected_alpha = selected_alpha.view(selected_alpha.size(0), *start_node.output_shape[1:], *selected_alpha.shape[2:])
+ selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]]
+ else:
+ assert alpha_lookup_idx.ndim == 3
+ # We only stored some alphas, and A is also sparse, so the unstable_idx must be first translated to real indices.
+ # alpha shape is (2, sparse_spec_shape, batch_size, *relu_node_shape) where relu_node_shape can also be sparse.
+                        # We use sparse-spec alphas. Need to convert unstable_idx[0], unstable_idx[1], unstable_idx[2] to real indices using the lookup table.
+ _unstable_idx = alpha_lookup_idx[unstable_idx[0], unstable_idx[1], unstable_idx[2]]
+ selected_alpha = self.non_deter_index_select(selected_alpha, index=_unstable_idx, dim=1)
else:
- # unstable index for patches mode. Need to choose based on unstable_idx.
- # Selection is just based on output channel, and it will be selected when d_pos_unfolded_r and d_neg_unfolded_r are constructed.
- selected_alpha = self.alpha[start_node.name]
+ # Patches mode. Alpha must be selected after unfolding, so cannot be done here.
+ # Selection is deferred to maybe_unfold() using alpha_lookup_idx.
+ # For partially shared alpha, its shape is (2, out_c, batch_size, *relu_node_shape).
+ # For full alpha, its shape is (2, out_c*out_h*out_w, batch_size, *relu_node_shape).
+ # Both the spec dimension and relu_node_shape dimensions can be sparse.
+ pass
elif unstable_idx.ndim == 1:
+ # Start node is a FC node.
# Only unstable neurons of the start_node neurons are used.
- selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
+ assert alpha_lookup_idx is None or alpha_lookup_idx.ndim == 1
+ _unstable_idx = alpha_lookup_idx[unstable_idx] if alpha_lookup_idx is not None else unstable_idx
+ selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=_unstable_idx, dim=1)
elif unstable_idx.ndim == 2:
+ assert alpha_lookup_idx is None, "sparse spec alpha has not been implemented yet."
# Each element in the batch selects different neurons.
selected_alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
else:
raise ValueError
else:
+ # Spec dimension is dense. Alpha must not be created sparsely.
+ assert self.alpha_lookup_idx[start_node.name] is None
selected_alpha = self.alpha[start_node.name]
- # print(f'{self.name} selecting {start_node.name} alpha {selected_alpha.size()}')
# The first dimension is lower/upper intermediate bound.
- if x is not None:
- lower = x.lower
- upper = x.upper
- else:
- lower = self.lower
- upper = self.upper
- lower_mask = lower > 0
- upper_mask = upper < 0
if last_lA is not None:
- lb_lower_d = selected_alpha[0].clamp(min=0.0, max=1.0)
- lb_lower_d[:, lower_mask] = 1.0
- lb_lower_d[:, upper_mask] = 0.0
+ lb_lower_d = selected_alpha[0]
if last_uA is not None:
- ub_lower_d = selected_alpha[1].clamp(min=0.0, max=1.0)
- ub_lower_d[:, lower_mask] = 1.0
- ub_lower_d[:, upper_mask] = 0.0
- self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = upper_mask.all().item()
- flag_expand = True
+ ub_lower_d = selected_alpha[1]
+
+ if self.alpha_indices is not None:
+ # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only.
+ # Recover to full alpha first.
+ def reconstruct_full_alpha(sparse_alpha, full_alpha_shape, alpha_indices):
+ full_alpha = torch.zeros(full_alpha_shape, dtype=sparse_alpha.dtype, device=sparse_alpha.device)
+ if len(alpha_indices) == 1:
+ # Relu after a dense layer.
+ full_alpha[:, :, alpha_indices[0]] = sparse_alpha
+ elif len(alpha_indices) == 3:
+ # Relu after a conv layer.
+ full_alpha[:, :, alpha_indices[0], alpha_indices[1], alpha_indices[2]] = sparse_alpha
+ else:
+ raise ValueError
+ return full_alpha
+ sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape
+ full_alpha_shape = sparse_alpha_shape[:-1] + self.shape
+ if lb_lower_d is not None:
+ lb_lower_d = reconstruct_full_alpha(lb_lower_d, full_alpha_shape, self.alpha_indices)
+ if ub_lower_d is not None:
+ ub_lower_d = reconstruct_full_alpha(ub_lower_d, full_alpha_shape, self.alpha_indices)
+
+        # Only apply the stable/unstable masking to the batch elements selected by the update mask.
+ if self.alpha_beta_update_mask is not None:
+ if lb_lower_d is not None:
+ lb_lower_d_new = lb_lower_d[:, self.alpha_beta_update_mask]
+ else:
+ lb_lower_d_new = None
+ if ub_lower_d is not None:
+ ub_lower_d_new = ub_lower_d[:, self.alpha_beta_update_mask]
+ else:
+ ub_lower_d_new = None
+ lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d_new, ub_lower_d_new)
+ else:
+ lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d, ub_lower_d)
+ self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = zero_coeffs
+ flag_expand = True # we already have the spec dimension.
+ elif self.relu_options == "same-slope":
+ # the same slope for upper and lower
+ lower_d = upper_d
+ elif self.relu_options == "zero-lb":
+ # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN
+ lower_d = (upper_d >= 1.0).to(upper_d.dtype)
+ elif self.relu_options == "one-lb":
+ # Always use slope 1 as lower bound
+ lower_d = (upper_d > 0.0).to(upper_d.dtype)
+ elif self.relu_options == "reversed-adaptive":
+ lower_d = (upper_d < 0.5).to(upper_d.dtype)
else:
# adaptive
- lower_d = (upper_d > 0.5).float()
-
- # save for calculate babsr score
- self.d = upper_d
- self.lA = last_lA
- # Save for initialization bounds.
- self.lower_d = lower_d
-
- # assert self.I.sum() == torch.logical_and(0 < self.d, self.d < 1).sum()
+ lower_d = (upper_d > 0.5).to(upper_d.dtype)
# Upper bound always needs an extra specification dimension, since they only depend on lb and ub.
upper_d = upper_d.unsqueeze(0)
upper_b = upper_b.unsqueeze(0)
if not flag_expand:
- if self.opt_stage == 'opt':
+ if self.opt_stage in ['opt', 'reuse']:
# We have different slopes for lower and upper bounds propagation.
lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None
ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None
else:
lower_d = lower_d.unsqueeze(0)
+ return upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx
- mode = "patches" if isinstance(last_lA, Patches) or isinstance(last_uA, Patches) else "matrix"
-
- # In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return.
- def _maybe_unfold(d_tensor, last_A):
- if mode == "matrix" or d_tensor is None or last_A is None:
- return d_tensor
- # Input are slopes with shape (spec, batch, input_c, input_h, input_w)
- # Here spec is the same as out_c.
- assert d_tensor.ndim == 5
- d_shape = d_tensor.size()
- # Reshape to 4-D tensor to unfold.
- d_tensor = d_tensor.view(-1, *d_shape[-3:])
- # unfold the slope matrix as patches. Patch shape is [spec * batch, out_h, out_w, in_c, H, W).
- d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding)
- # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c.
- d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])
- if last_A.unstable_idx is not None:
- if d_unfolded_r.size(0) == 1:
- # Broadcast the spec shape, so only need to select the reset dimensions.
- # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).
- d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
- d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
- # output shape: (unstable_size, batch, in_c, H, W).
- else:
- d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
- # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).
- # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W).
- return d_unfolded_r
+ def bound_backward(self, last_lA, last_uA, x=None, start_node=None, beta_for_intermediate_layers=False, unstable_idx=None):
+ # Get element-wise CROWN linear relaxations.
+ upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx = \
+ self._backward_relaxation(last_lA, last_uA, x, start_node, unstable_idx)
+        # Saved for calculating the BaBSR branching score.
+ self.d = upper_d
+ self.lA = last_lA
+ # Save for initialization bounds.
+ self.lower_d = lower_d
# Choose upper or lower bounds based on the sign of last_A
def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):
if last_A is None:
return None, 0
+ # Obtain the new linear relaxation coefficients based on the signs in last_A.
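+            # (Fused equivalent of the removed code below: positive entries of last_A pick up
+            # d_pos / b_pos and negative entries pick up d_neg / b_neg, with the bias terms
+            # accumulated via einsum over the spec and batch dimensions.)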
+ _A, _bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg)
+ if isinstance(last_A, Patches):
+ # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.
+ A_prod = _A.patches
+ if start_node is not None:
+ if last_A.unstable_idx is not None:
+ # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
+ self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]
+ else:
+ # Regular patches.
+ self.patch_size[start_node.name] = A_prod.size()
+ return _A, _bias
- if type(last_A) == torch.Tensor:
- # multiply according to sign of A (we use fused operation to save memory)
- # neg_A = last_A.clamp(max=0)
- # pos_A = last_A.clamp(min=0)
- # A = d_pos * pos_A + d_neg * neg_A
- A, pos_A, neg_A = self.clamp_mutiply(last_A, d_pos, d_neg)
- bias = 0
- if b_pos is not None:
- bias = bias + torch.einsum('sb...,sb...->sb', pos_A, b_pos)
- if b_neg is not None:
- bias = bias + torch.einsum('sb...,sb...->sb', neg_A, b_neg)
- return A, bias
- elif type(last_A) == Patches:
- # if last_A is not an identity matrix
- assert last_A.identity == 0
- if last_A.identity == 0:
- # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.
- # or (unstable_size, batch_size, in_c, H, W) when it is sparse.
- patches = last_A.patches
- prod, pos_A_patches, neg_A_patches = self.clamp_mutiply_non_contiguous(patches, d_pos, d_neg)
- # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.
-
- # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.
- if start_node is not None:
- if last_A.unstable_idx is not None:
- # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
- self.patch_size[start_node.name] = [last_A.output_shape[1], prod.size(1), last_A.output_shape[2], last_A.output_shape[3], prod.size(-3), prod.size(-2), prod.size(-1)]
- else:
- # Regular patches.
- self.patch_size[start_node.name] = prod.size()
-
- bias = 0
- if b_pos is not None:
- # For sparse patches the return bias size is (unstable_size, batch).
- # For regular patches the return bias size is (spec, batch, out_h, out_w).
- bias = bias + torch.einsum('sb...chw,sb...chw->sb...', b_pos, pos_A_patches)
- if b_neg is not None:
- bias = bias + torch.einsum('sb...chw,sb...chw->sb...', b_neg, neg_A_patches)
- return Patches(prod, last_A.stride, last_A.padding, prod.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), bias
+ ######## A problem with patches mode for cut constraint start ##########
+        # There are cases where a node appears in the cut constraint but is not selected by the patches of the output node.
+        # Trick: only keep the small patches whose split-node coefficients satisfy coeffs[ci].sum() == coeffs_unfolded[ci][out_h, out_w, -1].sum();
+        # the betas of the remaining constraints should be forced to 0 to disable their effect.
+ A = last_lA if last_lA is not None else last_uA
+ current_layer_shape = x.lower.size()[1:]
+ if self.cut_used and type(A) is Patches:
+ self.cut_module.patch_trick(start_node, self.name, A, current_layer_shape)
+ ######## A problem with patches mode for cut constraint end ##########
+
+ if self.cut_used:
+ # propagate postrelu node in cut constraints
+ last_lA, last_uA = self.cut_module.relu_cut(
+ start_node, self.name, last_lA, last_uA, current_layer_shape, unstable_idx,
+ batch_mask=self.alpha_beta_update_mask)
# In patches mode we might need an unfold.
- upper_d = _maybe_unfold(upper_d, last_lA if last_lA is not None else last_uA)
- lower_d = _maybe_unfold(lower_d, last_lA if last_lA is not None else last_uA)
- upper_b = _maybe_unfold(upper_b, last_lA if last_lA is not None else last_uA)
- ub_lower_d = _maybe_unfold(ub_lower_d, last_uA)
- lb_lower_d = _maybe_unfold(lb_lower_d, last_lA)
-
- uA, ubias = _bound_oneside(last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, None)
- lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, None, upper_b)
-
- self.masked_beta_lower = self.masked_beta_upper = None
- if self.options.get('optimize_bound_args', {}).get('ob_beta', False):
- if self.options.get('optimize_bound_args', {}).get('ob_single_node_split', False):
- # Beta-CROWN.
- A = last_uA if last_uA is not None else last_lA
- if type(A) is Patches:
+ # lower_d, upper_d, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None
+ upper_d = maybe_unfold_patches(upper_d, last_lA if last_lA is not None else last_uA)
+ lower_d = maybe_unfold_patches(lower_d, last_lA if last_lA is not None else last_uA)
+ upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA)
+ lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA) # for ReLU it is always None; keeping it here for completeness.
+ # ub_lower_d and lb_lower_d might have sparse spec dimension, so they may need alpha_lookup_idx to convert to actual spec dim.
+ ub_lower_d = maybe_unfold_patches(ub_lower_d, last_uA, alpha_lookup_idx=alpha_lookup_idx)
+ # optimizable slope lb_lower_d: spec (only channels in spec layer), batch, current_c, current_w, current_h
+ # patches mode lb_lower_d after unfold: unstable, batch, in_C, H, W
+ lb_lower_d = maybe_unfold_patches(lb_lower_d, last_lA, alpha_lookup_idx=alpha_lookup_idx)
+
+ if self.cut_used:
+ I = (x.lower < 0) * (x.upper > 0)
+ # propagate integer var of relu neuron (arelu) in cut constraints through relu layer
+ lA, uA, lbias, ubias = self.cut_module.arelu_cut(
+ start_node, self.name, last_lA, last_uA, lower_d, upper_d,
+ lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, self.patch_size,
+ current_layer_shape, unstable_idx,
+ batch_mask=self.alpha_beta_update_mask)
+ else:
+ uA, ubias = _bound_oneside(
+ last_uA, upper_d, ub_lower_d if lower_d is None else lower_d,
+ upper_b, lower_b)
+ lA, lbias = _bound_oneside(
+ last_lA, lb_lower_d if lower_d is None else lower_d, upper_d,
+ lower_b, upper_b)
+
+ # Regular Beta CROWN with single neuron split
+ def _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx):
+ if type(A) is Patches:
+ if self.options.get('enable_opt_interm_bounds', False):
+ # expand sparse_beta to full beta
+ beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name])
+ beta_indices = self.sparse_beta_loc[start_node.name]
+ self.masked_beta = torch.zeros(2, *self.shape).reshape(2, -1).to(A.patches.dtype)
+ self.non_deter_scatter_add(self.masked_beta, dim=1, index=beta_indices, src=beta_values.to(self.masked_beta.dtype))
+ self.masked_beta = self.masked_beta.reshape(2, *self.shape)
+ else:
+ if self.beta is None:
+ # Beta not used.
+ return lA, uA
# For patches mode, masked_beta will be used; sparse beta is not supported.
self.masked_beta = (self.beta[0] * self.beta_mask).requires_grad_()
- A_patches = A.patches
- # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W)
- masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride)
- if A.unstable_idx is not None:
- masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4)
- # After selection, the shape is (unstable_size, batch, in_c, H, W).
- masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]]
- else:
- # Add the spec (out_c) dimension.
- masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0)
- if uA is not None:
- uA = Patches(uA.patches + masked_beta_unfolded, uA.stride, uA.padding, uA.patches.shape, unstable_idx=uA.unstable_idx, output_shape=uA.output_shape)
- if lA is not None:
- lA = Patches(lA.patches - masked_beta_unfolded, lA.stride, lA.padding, lA.patches.shape, unstable_idx=lA.unstable_idx, output_shape=lA.output_shape)
- elif type(A) is torch.Tensor:
+ # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W)
+ A_patches = A.patches
+ masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding)
+ if A.unstable_idx is not None:
+ masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4, 5)
+ # After selection, the shape is (unstable_size, batch, in_c, H, W).
+ masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]]
+ else:
+ # Add the spec (out_c) dimension.
+ masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0)
+ if self.alpha_beta_update_mask is not None:
+ masked_beta_unfolded = masked_beta_unfolded[self.alpha_beta_update_mask]
+ if uA is not None:
+ uA = uA.create_similar(uA.patches + masked_beta_unfolded)
+ if lA is not None:
+ lA = lA.create_similar(lA.patches - masked_beta_unfolded)
+ elif type(A) is Tensor:
+ if self.options.get('enable_opt_interm_bounds', False):
+ # For matrix mode, beta is sparse.
+ beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name]).expand(lA.size(0), -1, -1)
+ # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension.
+ beta_indices = self.sparse_beta_loc[start_node.name].unsqueeze(0).expand(lA.size(0), -1, -1)
+ else:
# For matrix mode, beta is sparse.
beta_values = (self.sparse_beta * self.sparse_beta_sign).expand(lA.size(0), -1, -1)
# self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension.
beta_indices = self.sparse_beta_loc.unsqueeze(0).expand(lA.size(0), -1, -1)
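+            # Beta-CROWN folds each single-neuron split into the backward pass as a Lagrange
+            # multiplier: +beta is scatter-added into uA and -beta into lA at the split neuron's
+            # flattened location, with sparse_beta_sign encoding the split direction.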
- # For conv layer, the last dimension is flattened in indices.
- prev_size = A.size()
- if uA is not None:
- uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=beta_indices, src=beta_values)
- uA = uA.view(prev_size)
- if lA is not None:
- lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=beta_indices, src=beta_values.neg())
- lA = lA.view(prev_size)
- else:
- raise RuntimeError(f"Unknown type {type(A)} for A")
- # The code block below is for debugging and will be removed (until the end of this function).
- elif not self.options.get('optimize_bound_args', {}).get('ob_single_node_split', True):
- A = uA if uA is not None else lA
- if type(A) == torch.Tensor:
- device = A.device
- else:
- device = A.patches.device
- print_time = False
-
- if self.single_beta_used or self.split_beta_used or self.history_beta_used:
- start_time = time.time()
- history_compute_time, split_compute_time, split_convert_time = 0, 0, 0
- history_compute_time1, history_compute_time2 = 0, 0
- # assert len(self.split_beta) > 0, "split_beta_used or history_beta_used is True means there have to be one relu in one batch is used in split constraints"
- if self.single_beta_used:
- if beta_for_intermediate_layers:
- # We handle the refinement of intermediate layer after this split layer here. (the refinement for intermediate layers before the split is handled in compute_bounds().
- # print(f'single node beta for {start_node.name} with beta shape {self.single_intermediate_betas[start_node.name]["ub"].size()}')
- assert not self.history_beta_used
- assert not self.history_beta_used
- assert type(A) is not Patches
- if uA is not None:
- # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta])
- single_intermediate_beta = self.single_intermediate_betas[start_node.name]['ub']
- single_intermediate_beta = single_intermediate_beta.view(
- single_intermediate_beta.size(0), -1, single_intermediate_beta.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- single_intermediate_beta = self.non_deter_index_select(single_intermediate_beta, index=unstable_idx, dim=1)
- # This is the sign.
- single_intermediate_beta = single_intermediate_beta * self.single_beta_sign.unsqueeze(1)
- # We now generate a large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA.
- prev_size = uA.size()
- # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension.
- indices = self.single_beta_loc.unsqueeze(0).expand(uA.size(0), -1, -1)
- # We update uA here directly using sparse operation. Note the spec dimension is at the first!
- uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=indices, src=single_intermediate_beta.transpose(0,1))
- uA = uA.view(prev_size)
- if lA is not None:
- # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta])
- single_intermediate_beta = self.single_intermediate_betas[start_node.name]['lb']
- single_intermediate_beta = single_intermediate_beta.view(
- single_intermediate_beta.size(0), -1, single_intermediate_beta.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- single_intermediate_beta = self.non_deter_index_select(single_intermediate_beta, index=unstable_idx, dim=1)
- # This is the sign, for lower bound we need to negate.
- single_intermediate_beta = single_intermediate_beta * ( - self.single_beta_sign.unsqueeze(1))
- # We now generate a large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA.
- prev_size = lA.size()
- # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension.
- indices = self.single_beta_loc.unsqueeze(0).expand(lA.size(0), -1, -1)
- # We update lA here directly using sparse operation. Note the spec dimension is at the first!
- lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=indices, src=single_intermediate_beta.transpose(0,1))
- lA = lA.view(prev_size)
- else:
- self.masked_beta_lower = self.masked_beta_upper = self.masked_beta = self.beta * self.beta_mask
-
- ############################
- # sparse_coo version for history coeffs
- if self.history_beta_used:
- # history_compute_time = time.time()
- if beta_for_intermediate_layers:
- # print(f'history intermediate beta for {start_node.name} with beta shape {self.history_intermediate_betas[start_node.name]["ub"].size()}')
- if uA is not None:
- # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta])
- history_intermediate_beta = self.history_intermediate_betas[start_node.name]['ub']
- history_intermediate_beta = history_intermediate_beta.view(
- history_intermediate_beta.size(0), -1, history_intermediate_beta.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- history_intermediate_beta = self.non_deter_index_select(history_intermediate_beta, index=unstable_idx, dim=1)
- # new_history_coeffs has shape (batch, prod(nodes), n_max_history_beta)
- # new_history_c has shape (batch, n_max_history_beta)
- # This can generate a quite large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA.
- self.masked_beta_upper = torch.bmm(history_intermediate_beta, (
- self.new_history_coeffs * self.new_history_c.unsqueeze(1)).transpose(-1,
- -2))
- if lA is not None:
- history_intermediate_beta = self.history_intermediate_betas[start_node.name]['lb']
- history_intermediate_beta = history_intermediate_beta.view(
- history_intermediate_beta.size(0), -1, history_intermediate_beta.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- history_intermediate_beta = self.non_deter_index_select(history_intermediate_beta, index=unstable_idx, dim=1)
- self.masked_beta_lower = torch.bmm(history_intermediate_beta, (
- self.new_history_coeffs * self.new_history_c.unsqueeze(1)).transpose(-1,
- -2))
- else:
- # new_history_coeffs has shape (batch, prod(nodes), n_max_history_beta)
- # new_history_beta has shape (batch, m_max_history_beta)
- self.masked_beta_lower = self.masked_beta_upper = torch.bmm(self.new_history_coeffs, (
- self.new_history_beta * self.new_history_c).unsqueeze(-1)).squeeze(-1)
-
- # new split constraint
- if self.split_beta_used:
- split_convert_time = time.time()
- if self.split_coeffs["dense"] is None:
- assert not hasattr(self, 'split_intermediate_betas') # intermediate beta split must use the dense mode.
- ##### we can use repeat to further save the conversion time
- # since the new split constraint coeffs can be optimized, we can just save the index and assign optimized coeffs value to the sparse matrix
- self.new_split_coeffs = torch.zeros(self.split_c.size(0), self.flattened_nodes,
- dtype=torch.get_default_dtype(), device=device)
- # assign coeffs value to the first half batch
- self.new_split_coeffs[
- (self.split_coeffs["nonzero"][:, 0], self.split_coeffs["nonzero"][:, 1])] = \
- self.split_coeffs["coeffs"]
- # # assign coeffs value to the rest half batch with the same values since split constraint shared the same coeffs for >0/<0
- self.new_split_coeffs[(self.split_coeffs["nonzero"][:, 0] + int(self.split_c.size(0) / 2),
- self.split_coeffs["nonzero"][:, 1])] = self.split_coeffs["coeffs"]
- else:
- # batch = int(self.split_c.size(0)/2)
- # assign coeffs value to the first half batch and the second half batch
- self.new_split_coeffs = self.split_coeffs["dense"].repeat(2, 1)
- split_convert_time = time.time() - split_convert_time
- split_compute_time = time.time()
- if beta_for_intermediate_layers:
- assert hasattr(self, 'split_intermediate_betas')
- # print(f'split intermediate beta for {start_node.name} with beta shape {self.split_intermediate_betas[start_node.name]["ub"].size()}')
- if uA is not None:
- # upper bound betas for this set of intermediate neurons.
- # Make an extra spec dimension. Now new_split_coeffs has size (batch, specs, #nodes). Specs is the number of intermediate neurons of start node. The same split will be applied to all specs in a batch element.
- # masked_beta_upper has shape (batch, spec, #nodes)
- split_intermediate_betas = self.split_intermediate_betas[start_node.name]['ub']
- split_intermediate_betas = split_intermediate_betas.view(split_intermediate_betas.size(0), -1, split_intermediate_betas.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- split_intermediate_betas = self.non_deter_index_select(split_intermediate_betas, index=unstable_idx, dim=1)
- self.split_masked_beta_upper = split_intermediate_betas * (
- self.new_split_coeffs * self.split_c).unsqueeze(1)
- if lA is not None:
- split_intermediate_betas = self.split_intermediate_betas[start_node.name]['lb']
- split_intermediate_betas = split_intermediate_betas.view(split_intermediate_betas.size(0), -1, split_intermediate_betas.size(-1))
- if unstable_idx is not None:
- # Only unstable neurons of the start_node neurons are used.
- split_intermediate_betas = self.non_deter_index_select(split_intermediate_betas, index=unstable_idx, dim=1)
- self.split_masked_beta_lower = split_intermediate_betas * (
- self.new_split_coeffs * self.split_c).unsqueeze(1)
- else:
- # beta for final objective only. TODO: distinguish between lb and ub.
- self.split_masked_beta_upper = self.split_masked_beta_lower = self.new_split_coeffs * (
- self.split_beta * self.split_c)
- # add the new split constraint beta to the masked_beta
- if self.masked_beta_upper is None:
- self.masked_beta_upper = self.split_masked_beta_upper
- else:
- self.masked_beta_upper = self.masked_beta_upper + self.split_masked_beta_upper
+ # For conv layer, the last dimension is flattened in indices.
+ prev_size = A.size()
+ if self.alpha_beta_update_mask is not None:
+ beta_indices = beta_indices[:, self.alpha_beta_update_mask]
+ beta_values = beta_values[:, self.alpha_beta_update_mask]
+ if uA is not None:
+ uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=beta_indices, src=beta_values.to(uA.dtype))
+ uA = uA.view(prev_size)
+ if lA is not None:
+ lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=beta_indices, src=beta_values.neg().to(lA.dtype))
+ lA = lA.view(prev_size)
+ else:
+ raise RuntimeError(f"Unknown type {type(A)} for A")
+ return lA, uA
- if self.masked_beta_lower is None:
- self.masked_beta_lower = self.split_masked_beta_lower
- else:
- self.masked_beta_lower = self.masked_beta_lower + self.split_masked_beta_lower
- # For backwards compatibility - we originally only have one beta.
- self.masked_beta = self.masked_beta_lower
- split_compute_time = time.time() - split_compute_time
-
- A = last_uA if last_uA is not None else last_lA
- if type(A) is Patches:
- assert not hasattr(self, 'split_intermediate_betas')
- assert not hasattr(self, 'single_intermediate_betas')
- A_patches = A.patches
- # Reshape beta to image size.
- self.masked_beta = self.masked_beta.view(self.masked_beta.size(0), *ub_r.size()[1:])
- # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W)
- masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride)
- if A.unstable_idx is not None:
- masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4)
- # After selection, the shape is (unstable_size, batch, in_c, H, W).
- masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]]
- else:
- # Add the spec (out_c) dimension.
- masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0)
- if uA is not None:
- uA = Patches(uA.patches + masked_beta_unfolded, uA.stride, uA.padding, uA.patches.shape, unstable_idx=uA.unstable_idx, output_shape=uA.output_shape)
- if lA is not None:
- lA = Patches(lA.patches - masked_beta_unfolded, lA.stride, lA.padding, lA.patches.shape, unstable_idx=lA.unstable_idx, output_shape=lA.output_shape)
- elif type(A) is torch.Tensor:
- if uA is not None:
- # print("uA", uA.shape, self.masked_beta.shape)
- # uA/lA has shape (spec, batch, *nodes)
- if beta_for_intermediate_layers:
- if not self.single_beta_used:
- # masked_beta_upper has shape (batch, spec, #nodes)
- self.masked_beta_upper = self.masked_beta_upper.transpose(0, 1)
- self.masked_beta_upper = self.masked_beta_upper.view(self.masked_beta_upper.size(0),
- self.masked_beta_upper.size(1),
- *uA.shape[2:])
- else:
- # masked_beta_upper has shape (batch, #nodes)
- self.masked_beta_upper = self.masked_beta_upper.reshape(uA[0].shape).unsqueeze(0)
- if not self.single_beta_used or not beta_for_intermediate_layers:
- # For intermediate layer betas witn single node split, uA has been modified above.
- uA = uA + self.masked_beta_upper
- if lA is not None:
- # print("lA", lA.shape, self.masked_beta.shape)
- if beta_for_intermediate_layers:
- if not self.single_beta_used:
- # masked_beta_upper has shape (batch, spec, #nodes)
- self.masked_beta_lower = self.masked_beta_lower.transpose(0, 1)
- self.masked_beta_lower = self.masked_beta_lower.view(self.masked_beta_lower.size(0),
- self.masked_beta_lower.size(1),
- *lA.shape[2:])
- else:
- # masked_beta_upper has shape (batch, #nodes)
- self.masked_beta_lower = self.masked_beta_lower.reshape(lA[0].shape).unsqueeze(0)
- if not self.single_beta_used or not beta_for_intermediate_layers:
- # For intermediate layer betas witn single node split, lA has been modified above.
- lA = lA - self.masked_beta_lower
- else:
- raise RuntimeError(f"Unknown type {type(A)} for A")
- # print("total:", time.time()-start_time, history_compute_time1, history_compute_time2, split_convert_time, split_compute_time)
+ if self.cut_used:
+ # propagate prerelu node in cut constraints
+ lA, uA = self.cut_module.pre_cut(start_node, self.name, lA, uA, current_layer_shape, unstable_idx,
+ batch_mask=self.alpha_beta_update_mask)
+ self.masked_beta_lower = self.masked_beta_upper = None
+ if self.options.get('optimize_bound_args', {}).get('enable_beta_crown', False) and self.sparse_beta is not None:
+ if self.options.get('optimize_bound_args', {}).get('single_node_split', False):
+                # Beta-CROWN: each split constraint involves only a single neuron (e.g., the second ReLU neuron > 0).
+ A = lA if lA is not None else uA
+ lA, uA = _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx)
+ # The code block below is for debugging and will be removed (until the end of this function).
+ # elif False and not self.options.get('optimize_bound_args', {}).get('single_node_split', True):
+            # # Improved Beta-CROWN: (1) general split constraints: each split constraint can involve multiple neurons
+            # # (e.g., second ReLU neuron > 0); (2) intermediate ReLU bounds refinement with the general split constraints.
+ # A = uA if uA is not None else lA
+ # lA, uA = _beta_crown_multi_neuron_splits(x, A, uA, lA, unstable_idx, start_node)
+ # print(lA.sum(), uA.sum())
+ # exit()
return [(lA, uA)], lbias, ubias
def interval_propagate(self, *v):
- if Interval.use_relative_bounds(*v):
- nominal = F.relu(v[0].nominal)
- mask_nominal = (nominal > 0).float()
- mask_l = (v[0].lower > 0).float()
- mask_u = (v[0].upper > 0).float()
- lower_offset = mask_nominal * (mask_l * v[0].lower_offset + (1 - mask_l) * (-nominal))
- upper_offset = mask_nominal * v[0].upper_offset + (1 - mask_nominal) * mask_u * v[0].upper
- return Interval(None, None, nominal, lower_offset, upper_offset)
-
h_L, h_U = v[0][0], v[0][1]
-
return F.relu(h_L), F.relu(h_U)
- def bound_forward(self, dim_in, x):
- return super().bound_forward(dim_in, x)
-
-class BoundSqrt(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
-
- @Bound.save_io_shape
- def forward(self, x):
- return torch.sqrt(x)
-
- def interval_propagate(self, *v):
- if Interval.use_relative_bounds(*v):
- nominal = self.forward(v[0].nominal)
- lower_offset = self.forward(v[0].nominal + v[0].lower_offset) - nominal
- upper_offset = self.forward(v[0].nominal + v[0].upper_offset) - nominal
- return Interval(None, None, nominal, lower_offset, upper_offset)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+        # e.g., the Gurobi variables of this layer's input, with a shape like (8, 16, 16).
+ gvars_array = np.array(v[0])
+ this_layer_shape = gvars_array.shape
+ assert gvars_array.shape == self.output_shape[1:]
+
+ pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1)
+ pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1)
+
+ new_layer_gurobi_vars = []
+ relu_integer_vars = []
+ new_relu_layer_constrs = []
+ # predefined zero variable shared in the whole solver model
+ zero_var = model.getVarByName("zero")
+
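+        # Encoding sketch for one unstable neuron with pre-activation bounds [l, u] (l < 0 < u)
+        # and indicator a: "mip" / "lp_integer" use the standard big-M formulation
+        #     y <= x - l * (1 - a),   y >= x,   y <= u * a,   y >= 0,
+        # with a binary for "mip" and relaxed to [0, 1] for "lp_integer", while "lp" uses the
+        # triangle relaxation whose upper face u*x - (u - l)*y >= u*l is y <= u*(x - l)/(u - l).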
+ for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)):
+ pre_ub = pre_ubs[neuron_idx]
+ pre_lb = pre_lbs[neuron_idx]
+
+ if pre_lb >= 0:
+ # ReLU is always passing
+ var = pre_var
+ elif pre_ub <= 0:
+ var = zero_var
+ else:
+ ub = pre_ub
+
+ var = model.addVar(ub=ub, lb=pre_lb,
+ obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'ReLU{self.name}_{neuron_idx}')
+
+ if model_type == "mip" or model_type == "lp_integer":
+ # binary indicator
+ if model_type == "mip":
+ a = model.addVar(vtype=grb.GRB.BINARY, name=f'aReLU{self.name}_{neuron_idx}')
+ elif model_type == "lp_integer":
+ a = model.addVar(ub=1, lb=0, vtype=grb.GRB.CONTINUOUS, name=f'aReLU{self.name}_{neuron_idx}')
+ relu_integer_vars.append(a)
+
+ new_relu_layer_constrs.append(
+ model.addConstr(pre_var - pre_lb * (1 - a) >= var,
+ name=f'ReLU{self.name}_{neuron_idx}_a_0'))
+ new_relu_layer_constrs.append(
+ model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1'))
+ new_relu_layer_constrs.append(
+ model.addConstr(pre_ub * a >= var, name=f'ReLU{self.name}_{neuron_idx}_a_2'))
+ new_relu_layer_constrs.append(
+ model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_3'))
+
+ elif model_type == "lp":
+ new_relu_layer_constrs.append(
+ model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_0'))
+ new_relu_layer_constrs.append(
+ model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1'))
+ new_relu_layer_constrs.append(model.addConstr(
+ pre_ub * pre_var - (pre_ub - pre_lb) * var >= pre_ub * pre_lb,
+ name=f'ReLU{self.name}_{neuron_idx}_a_2'))
- return super().interval_propagate(*v)
+ else:
+ print(f"gurobi model type {model_type} not supported!")
-class BoundReciprocal(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
+ new_layer_gurobi_vars.append(var)
- @Bound.save_io_shape
- def forward(self, x):
- return torch.reciprocal(x)
+ new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist()
+ if model_type in ["mip", "lp_integer"]:
+ self.integer_vars = relu_integer_vars
+ self.solver_vars = new_layer_gurobi_vars
+ self.solver_constrs = new_relu_layer_constrs
+ model.update()
- def bound_relax(self, x):
- m = (x.lower + x.upper) / 2
- kl = -1 / m.pow(2)
- self._add_linear(mask=None, type='lower', k=kl, x0=m, y0=1. / m)
- ku = -1. / (x.lower * x.upper)
- self._add_linear(mask=None, type='upper', k=ku, x0=x.lower, y0=1. / x.lower)
+ def dump_optimized_params(self):
+ return {
+ 'alpha': self.alpha,
+ 'alpha_lookup_idx': self.alpha_lookup_idx,
+ 'alpha_indices': self.alpha_indices
+ }
- def interval_propagate(self, *v):
- h_L, h_U = v[0][0].float(), v[0][1].float()
- assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal'
- return torch.reciprocal(h_U), torch.reciprocal(h_L)
+ def restore_optimized_params(self, opt_var_dict):
+ self.alpha, self.alpha_lookup_idx, self.alpha_indices = \
+ opt_var_dict['alpha'], opt_var_dict['alpha_lookup_idx'], opt_var_dict['alpha_indices']
-class BoundSin(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.max_point = math.pi / 2
- self.min_point = math.pi * 3 / 2
+class BoundLeakyRelu(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.options = options.get('relu')
+ self.alpha = attr['alpha']
- @Bound.save_io_shape
def forward(self, x):
- return torch.sin(x)
-
- def interval_propagate(self, *v):
- # Check if a point is in [l, u], considering the 2pi period
- def check_crossing(ll, uu, point):
- return ((((uu - point) / (2 * math.pi)).floor() - ((ll - point) / (2 * math.pi)).floor()) > 0).float()
- h_L, h_U = v[0][0], v[0][1]
- h_Ls, h_Us = self.forward(h_L), self.forward(h_U)
- # If crossing pi/2, then max is fixed 1.0
- max_mask = check_crossing(h_L, h_U, self.max_point)
- # If crossing pi*3/2, then min is fixed -1.0
- min_mask = check_crossing(h_L, h_U, self.min_point)
- ub = torch.max(h_Ls, h_Us)
- ub = max_mask + (1 - max_mask) * ub
- lb = torch.min(h_Ls, h_Us)
- lb = - min_mask + (1 - min_mask) * lb
- return lb, ub
+ return F.leaky_relu(x, negative_slope=self.alpha)
- def bound_backward(self, last_lA, last_uA, *x, start_node=None, start_shape=None):
- return not_implemented_op(self, 'bound_backward')
+ def bound_backward(self, last_lA, last_uA, x=None, start_node=None, start_shape=None):
+ if x is not None:
+ lb_r = x.lower.clamp(max=0)
+ ub_r = x.upper.clamp(min=0)
+ else:
+ lb_r = self.lower.clamp(max=0)
+ ub_r = self.upper.clamp(min=0)
+ ub_r = torch.max(ub_r, lb_r + 1e-8)
+ upper_d = (ub_r - self.alpha * lb_r) / (ub_r - lb_r)
+ upper_b = - lb_r * upper_d + self.alpha * lb_r
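+        # For an unstable neuron this is the chord through (l, alpha * l) and (u, u):
+        # slope (u - alpha * l) / (u - l) and intercept alpha * l - l * upper_d.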
+ if self.options == "same-slope":
+ # the same slope for upper and lower
+ lower_d = upper_d
+ elif self.options == "zero-lb":
+ # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN
+ lower_d = (upper_d >= 1.0).to(upper_d.dtype) + (upper_d < 1.0).to(upper_d.dtype) * self.alpha
+ elif self.options == "one-lb":
+ # Always use slope 1 as lower bound
+            lower_d = (upper_d > 0.0).to(upper_d.dtype) + (upper_d <= 0.0).to(upper_d.dtype) * self.alpha
+ else:
+            lower_d = (upper_d > 0.5).to(upper_d.dtype) + (upper_d <= 0.5).to(upper_d.dtype) * self.alpha
-class BoundCos(BoundSin):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.max_point = 0.0
- self.min_point = math.pi
+ upper_d = upper_d.unsqueeze(0)
+ lower_d = lower_d.unsqueeze(0)
+ # Choose upper or lower bounds based on the sign of last_A
+ uA = lA = None
+ ubias = lbias = 0
+ if last_uA is not None:
+ neg_uA = last_uA.clamp(max=0)
+ pos_uA = last_uA.clamp(min=0)
+ uA = upper_d * pos_uA + lower_d * neg_uA
+ ubias = self.get_bias(pos_uA, upper_b)
+ if last_lA is not None:
+ neg_lA = last_lA.clamp(max=0)
+ pos_lA = last_lA.clamp(min=0)
+ lA = upper_d * neg_lA + lower_d * pos_lA
+ lbias = self.get_bias(neg_lA, upper_b)
+ return [(lA, uA)], lbias, ubias
- @Bound.save_io_shape
- def forward(self, x):
- return torch.cos(x)
+ def dump_optimized_params(self):
+ return self.alpha
+ def restore_optimized_params(self, alpha):
+ self.alpha = alpha
class BoundTanh(BoundOptimizableActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.precompute_relaxation('tanh', torch.tanh, self.dtanh)
+ # Alpha dimension is (4, 2, output_shape, batch, *shape) for Tanh.
+ self.alpha_batch_dim = 3
def opt_init(self):
super().opt_init()
- self.tp_both_lower_init = {}
+ self.tp_both_lower_init = {}
self.tp_both_upper_init = {}
def init_opt_parameters(self, start_nodes):
- self.alpha = OrderedDict()
l, u = self.inputs[0].lower, self.inputs[0].upper
shape = l.shape
- for ns, size_s in start_nodes:
+ for ns, size_s, _ in start_nodes:
+ if isinstance(size_s, torch.Size):
+ size_s = prod(size_s)
self.alpha[ns] = torch.empty(4, 2, size_s, *shape, device=l.device)
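+            # The four slots along dim 0 are tangent points for the four relaxation cases in
+            # bound_relax_impl: [0] the upper bound when the input lower bound >= 0, [1] the
+            # lower bound when the input upper bound <= 0, and [2]/[3] the lower/upper bounds
+            # for unstable neurons; [0]/[1] start at the midpoint (l + u) / 2 and [2]/[3] at
+            # the pre-computed tangent points tp_both_lower_init / tp_both_upper_init.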
self.alpha[ns].data[:2] = ((l + u) / 2).unsqueeze(0).expand(2, 2, size_s, *shape)
self.alpha[ns].data[2] = self.tp_both_lower_init[ns].expand(2, size_s, *shape)
@@ -863,15 +986,16 @@ def init_opt_parameters(self, start_nodes):
def dtanh(self, x):
# to avoid bp error when cosh is too large
# cosh(25.0)**2 > 1e21
- mask = torch.lt(torch.abs(x), 25.0).float()
+ mask = torch.lt(torch.abs(x), 25.0).to(x.dtype)
cosh = torch.cosh(mask * x + 1 - mask)
return mask * (1. / cosh.pow(2))
- """Precompute relaxation parameters for tanh and sigmoid"""
-
@torch.no_grad()
- def precompute_relaxation(self, name, func, dfunc):
- self.x_limit = 500
+    def precompute_relaxation(self, name, func, dfunc, x_limit=500):
+ """
+        This function precomputes the tangent lines that will be used as lower/upper bounds for S-shaped functions.
+ """
+ self.x_limit = x_limit
self.step_pre = 0.01
self.num_points_pre = int(self.x_limit / self.step_pre)
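+        # The d_lower / d_upper tables computed below hold one safe tangent point per
+        # discretized bound value in [0, x_limit] at step self.step_pre (0.01); at bound
+        # computation time, bound_relax_impl fetches each neuron's tangent point with an
+        # index_select at index floor(|bound| / step_pre) + 1.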
max_iter = 100
@@ -879,35 +1003,49 @@ def precompute_relaxation(self, name, func, dfunc):
logger.debug('Precomputing relaxation for {}'.format(name))
def check_lower(upper, d):
+ """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper."""
k = dfunc(d)
+            # Return True if the tangent line at d is a lower bound.
return k * (upper - d) + func(d) <= func(upper)
def check_upper(lower, d):
+ """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower."""
k = dfunc(d)
+            # Return True if the tangent line at d is an upper bound.
return k * (lower - d) + func(d) >= func(lower)
+ # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function.
upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device)
r = torch.zeros_like(upper)
+        # Initial guess: the tangent point is at -1.
l = -torch.ones_like(upper)
while True:
+            # Check if the tangent line at the guessed point is a lower bound of f(upper).
checked = check_lower(upper, l).int()
+            # If the initial guess is not small enough, double it (-2, -4, etc.).
l = checked * l + (1 - checked) * (l * 2)
- if checked.sum() == l.numel():
+ if checked.sum() == l.numel():
break
+        # Now we have a starting point l whose tangent line is guaranteed to be a lower bound at f(upper).
+ # We want to further tighten this bound by moving it closer to 0.
for t in range(max_iter):
+ # Binary search.
m = (l + r) / 2
checked = check_lower(upper, m).int()
l = checked * m + (1 - checked) * l
r = checked * r + (1 - checked) * m
+        # The tangent line at point l is guaranteed to lower bound the function at upper.
self.d_lower = l.clone()
+ # Do the same again:
+        # Given a lower bound point (<= 0), find a line that is guaranteed to be an upper bound of this function.
lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device)
l = torch.zeros_like(upper)
r = torch.ones_like(upper)
while True:
checked = check_upper(lower, r).int()
r = checked * r + (1 - checked) * (r * 2)
- if checked.sum() == l.numel():
+ if checked.sum() == l.numel():
break
for t in range(max_iter):
m = (l + r) / 2
@@ -918,12 +1056,11 @@ def check_upper(lower, d):
logger.debug('Done')
- @Bound.save_io_shape
def forward(self, x):
return torch.tanh(x)
def bound_relax_impl(self, x, func, dfunc):
- # When self.x_limit is large enough, torch.tanh(self.x_limit)=1,
+ # When self.x_limit is large enough, torch.tanh(self.x_limit)=1,
# and thus clipping is valid
lower = x.lower.clamp(min=-self.x_limit)
upper = x.upper.clamp(max=self.x_limit)
@@ -931,101 +1068,134 @@ def bound_relax_impl(self, x, func, dfunc):
min_preact = 1e-6
mask_close = (upper - lower) < min_preact
- k_direct = k = torch.where(mask_close,
- dfunc(upper), (y_u - y_l) / (upper - lower).clamp(min=min_preact))
-
- # Fixed bounds that cannot be optimized
- # upper bound for negative
- self._add_linear(mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l)
- # lower bound for positive
- self._add_linear(mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l)
-
+        # k_direct is the slope of the line directly connecting (lower, func(lower)) and (upper, func(upper)).
+ k_direct = k = torch.where(mask_close,
+ dfunc(upper), (y_u - y_l) / (upper - lower).clamp(min=min_preact))
+
+ # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0.
+        # Upper bound for the case of input upper bound <= 0 is always the direct line.
+ self.add_linear_relaxation(mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l)
+        # Lower bound for the case of input lower bound >= 0 is always the direct line.
+ self.add_linear_relaxation(mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l)
+
+        # Indices of neurons with input upper bound >= 0, whose optimal tangent point for lower bounding the function was pre-computed.
+        # Note that neurons whose input lower bound is also >= 0 will be masked later.
index = torch.max(
torch.zeros(upper.numel(), dtype=torch.long, device=upper.device),
(upper / self.step_pre).to(torch.long).reshape(-1)
) + 1
+ # Lookup the lower bound slope from the pre-computed table.
d_lower = torch.index_select(self.d_lower, 0, index).view(lower.shape)
+        # Indices of neurons with input lower bound <= 0, whose optimal tangent point for upper bounding the function was pre-computed.
index = torch.max(
torch.zeros(lower.numel(), dtype=torch.long, device=lower.device),
(lower / -self.step_pre).to(torch.long).reshape(-1)
) + 1
- d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape)
+ d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape)
- ns = self._start.name
-
- # bound with tangent lines can be optimized
- if self.opt_stage == 'opt':
+ if self.opt_stage in ['opt', 'reuse']:
if not hasattr(self, 'alpha'):
+ # Raise an error if alpha is not created.
self._no_bound_parameters()
+ ns = self._start
# Clipping is done here rather than after `opt.step()` call
- # because it depends on pre-activation bounds
+ # because it depends on pre-activation bounds
self.alpha[ns].data[0, :] = torch.max(torch.min(self.alpha[ns][0, :], upper), lower)
self.alpha[ns].data[1, :] = torch.max(torch.min(self.alpha[ns][1, :], upper), lower)
self.alpha[ns].data[2, :] = torch.min(self.alpha[ns][2, :], d_lower)
self.alpha[ns].data[3, :] = torch.max(self.alpha[ns][3, :], d_upper)
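+        # The projection above keeps each tangent point sound: slots [0]/[1] must stay inside
+        # [lower, upper], while slot [2] is clamped to lie at or left of the pre-computed safe
+        # point d_lower and slot [3] at or right of d_upper.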
+ # shape [2, out_c, n, c, h, w].
tp_pos = self.alpha[ns][0]
tp_neg = self.alpha[ns][1]
tp_both_lower = self.alpha[ns][2]
- tp_both_upper = self.alpha[ns][3]
+ tp_both_upper = self.alpha[ns][3]
# No need to use tangent line, when the tangent point is at the left
# side of the preactivation lower bound. Simply connect the two sides.
- mask_direct = self.mask_both * ( k_direct < dfunc(lower) )
- self._add_linear(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)
- self._add_linear(mask=self.mask_both - mask_direct, type='lower',
- k=dfunc(tp_both_lower), x0=tp_both_lower,
+ mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower))
+ self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)
+ self.add_linear_relaxation(
+ mask=torch.logical_xor(self.mask_both, mask_direct), type='lower',
+ k=dfunc(tp_both_lower), x0=tp_both_lower,
y0=self.forward(tp_both_lower))
- mask_direct = self.mask_both * ( k_direct < dfunc(upper) )
- self._add_linear(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)
- self._add_linear(mask=self.mask_both - mask_direct, type='upper',
- k=dfunc(tp_both_upper), x0=tp_both_upper,
+ mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper))
+ self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)
+ self.add_linear_relaxation(
+ mask=torch.logical_xor(self.mask_both, mask_direct), type='upper',
+ k=dfunc(tp_both_upper), x0=tp_both_upper,
y0=self.forward(tp_both_upper))
- self._add_linear(mask=self.mask_neg, type='lower',
+ self.add_linear_relaxation(
+ mask=self.mask_neg, type='lower',
k=dfunc(tp_neg), x0=tp_neg, y0=self.forward(tp_neg))
- self._add_linear(mask=self.mask_pos, type='upper',
+ self.add_linear_relaxation(
+ mask=self.mask_pos, type='upper',
k=dfunc(tp_pos), x0=tp_pos, y0=self.forward(tp_pos))
else:
+ # Not optimized (vanilla CROWN bound).
+ # Use the tangent line at the middle point as the lower/upper bound.
m = (lower + upper) / 2
y_m = func(m)
k = dfunc(m)
- # lower bound for negative
- self._add_linear(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m)
- # upper bound for positive
- self._add_linear(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m)
-
+ # The lower bound is the tangent line at the middle point, for the case of input upper bound <= 0.
+ # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)).
+ self.add_linear_relaxation(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m)
+ # The upper bound is the tangent line at the middle point, for the case of input lower bound >= 0.
+ # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)).
+ self.add_linear_relaxation(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m)
+
+ # Now handle the case where input lower bound <= 0 and upper bound >= 0.
+ # A tangent line at tangent point d_lower is guaranteed to be a valid lower bound, given the input upper bound.
k = dfunc(d_lower)
y0 = func(d_lower)
if self.opt_stage == 'init':
+ # Initialize optimizable slope.
+ ns = self._start
self.tp_both_lower_init[ns] = d_lower.detach()
- mask_direct = self.mask_both * ( k_direct < dfunc(lower) )
- self._add_linear(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)
- self._add_linear(mask=self.mask_both - mask_direct, type='lower', k=k, x0=d_lower, y0=y0)
-
+ # Another possibility is to use the direct line as the lower bound, when this direct line does not intersect f.
+ # This is only valid when the function's slope at the input lower bound is greater than the slope of the direct line.
+ mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower))
+ self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)
+ # Otherwise we do not use the direct line; we use the tangent line at d_lower.
+ self.add_linear_relaxation(
+ mask=torch.logical_xor(self.mask_both, mask_direct),
+ type='lower', k=k, x0=d_lower, y0=y0)
+
+ # Do the same for the upper bound side when input lower bound <= 0 and upper bound >= 0.
k = dfunc(d_upper)
y0 = func(d_upper)
if self.opt_stage == 'init':
- self.tp_both_upper_init[ns] = d_upper.detach()
+ ns = self._start
+ self.tp_both_upper_init[ns] = d_upper.detach()
self.tmp_lower = x.lower.detach()
self.tmp_upper = x.upper.detach()
- mask_direct = self.mask_both * ( k_direct < dfunc(upper) )
- self._add_linear(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)
- self._add_linear(mask=self.mask_both - mask_direct, type='upper', k=k, x0=d_upper, y0=y0)
+ mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper))
+ self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)
+ self.add_linear_relaxation(
+ mask=torch.logical_xor(self.mask_both, mask_direct),
+ type='upper', k=k, x0=d_upper, y0=y0)
def bound_relax(self, x):
self.bound_relax_impl(x, torch.tanh, self.dtanh)
+ def dump_optimized_params(self):
+ return self.alpha
+
+ def restore_optimized_params(self, alpha):
+ self.alpha = alpha
+
class BoundSigmoid(BoundTanh):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super(BoundTanh, self).__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super(BoundTanh, self).__init__(attr, inputs, output_index, options)
self.precompute_relaxation('sigmoid', torch.sigmoid, self.dsigmoid)
+ # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions.
+ self.alpha_batch_dim = 3
- @Bound.save_io_shape
def forward(self, x):
return torch.sigmoid(x)
@@ -1037,196 +1207,125 @@ def bound_relax(self, x):
class BoundSoftplus(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super(BoundSoftplus, self).__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super(BoundSoftplus, self).__init__(attr, inputs, output_index, options)
self.softplus = nn.Softplus()
- @Bound.save_io_shape
def forward(self, x):
- return self.softplus(x)
+ return self.softplus(x)
-class BoundExp(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.options = options.get('exp')
- self.max_input = 0
+class BoundAbs(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, x):
- if self.loss_fusion and self.options != 'no-max-input':
- self.max_input = torch.max(x, dim=-1, keepdim=True)[0].detach()
- return torch.exp(x - self.max_input)
- return torch.exp(x)
-
- def interval_propagate(self, *v):
- assert (len(v) == 1)
-
- if Interval.use_relative_bounds(*v):
- assert not self.loss_fusion or self.options == 'no-max-input'
- nominal = torch.exp(v[0].nominal)
- return Interval(
- None, None,
- nominal,
- nominal * (torch.exp(v[0].lower_offset) - 1),
- nominal * (torch.exp(v[0].upper_offset) - 1)
- )
-
- # unary monotonous functions only
- h_L, h_U = v[0]
- if self.loss_fusion and self.options != 'no-max-input':
- self.max_input = torch.max(h_U, dim=-1, keepdim=True)[0]
- h_L, h_U = h_L - self.max_input, h_U - self.max_input
- else:
- self.max_input = 0
- return torch.exp(h_L), torch.exp(h_U)
-
- def bound_forward(self, dim_in, x):
- m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)
-
- exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)
-
- kl = exp_m
- lw = x.lw * kl.unsqueeze(1)
- lb = kl * (x.lb - m + 1)
-
- ku = (exp_u - exp_l) / (x.upper - x.lower + epsilon)
- uw = x.uw * ku.unsqueeze(1)
- ub = x.ub * ku - ku * x.lower + exp_l
+ return x.abs()
- return LinearBound(lw, lb, uw, ub)
-
- def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None):
- # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable).
- if self.loss_fusion and last_lA is None and last_uA is not None and torch.min(
- last_uA) >= 0 and x.from_input:
- # Adding an extra bias term to the input. This is equivalent to adding a constant and subtract layer before exp.
- # Note that we also need to adjust the bias term at the end.
- if self.options == 'no-detach':
- self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0]
- elif self.options != 'no-max-input':
- self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0].detach()
+ def bound_backward(self, last_lA, last_uA, x):
+ x_L = x.lower.clamp(max=0)
+ x_U = torch.max(x.upper.clamp(min=0), x_L + 1e-8)
+ mask_neg = x_U <= 0
+ mask_pos = x_L >= 0
+ y_L = x_L.abs()
+ y_U = x_U.abs()
+ upper_k = (y_U - y_L) / (x_U - x_L)
+ upper_b = y_L - upper_k * x_L
+ lower_k = (mask_neg * (-1.0) + mask_pos * 1.0)
+ lower_b = (mask_neg + mask_pos) * ( y_L - lower_k * x_L )
+ if last_uA is not None:
+ # Special case: when all coefficients in last_uA are non-negative, only the upper bound relaxation is needed.
+ if last_uA.min() >= 0:
+ uA = last_uA * upper_k
+ ubias = self.get_bias(last_uA, upper_b)
else:
- self.max_input = 0
- adjusted_lower = x.lower - self.max_input
- adjusted_upper = x.upper - self.max_input
- # relaxation for upper bound only (used in loss fusion)
- exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper)
- k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower + epsilon)
- if k.requires_grad:
- k = k.clamp(min=1e-6)
- uA = last_uA * k.unsqueeze(0)
- ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0)
-
- if ubias.ndim > 2:
- ubias = torch.sum(ubias, dim=tuple(range(2, ubias.ndim)))
- # Also adjust the missing ubias term.
- if uA.ndim > self.max_input.ndim:
- A = torch.sum(uA, dim=tuple(range(self.max_input.ndim, uA.ndim)))
+ last_uA_pos = last_uA.clamp(min=0)
+ last_uA_neg = last_uA.clamp(max=0)
+ uA = last_uA_pos * upper_k + last_uA_neg * lower_k
+ ubias = (self.get_bias(last_uA_pos, upper_b)
+ + self.get_bias(last_uA_neg, lower_b))
+ else:
+ uA, ubias = None, 0
+ if last_lA is not None:
+ if last_lA.max() <= 0:
+ lA = last_lA * upper_k
+ lbias = self.get_bias(last_lA, upper_b)
else:
- A = uA
-
- # These should hold true in loss fusion
- assert self.batch_dim == 0
- assert A.shape[0] == 1
-
- batch_size = A.shape[1]
- ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0)
- return [(None, uA)], 0, ubias
+ last_lA_pos = last_lA.clamp(min=0)
+ last_lA_neg = last_lA.clamp(max=0)
+ lA = last_lA_pos * lower_k + last_lA_neg * upper_k
+ lbias = (self.get_bias(last_lA_pos, lower_b)
+ + self.get_bias(last_lA_neg, upper_b))
else:
- return super().bound_backward(last_lA, last_uA, x)
+ lA, lbias = None, 0
+ return [(lA, uA)], lbias, ubias
- def bound_relax(self, x):
- min_val = -1e9
- l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val)
- m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)
- exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)
- k = exp_m
- self._add_linear(mask=None, type='lower', k=k, x0=m, y0=exp_m)
- min_val = -1e9 # to avoid (-inf)-(-inf) when both input.lower and input.upper are -inf
- epsilon = 1e-20
- close = (u - l < epsilon).int()
- k = close * exp_u + (1 - close) * (exp_u - exp_l) / (u - l + epsilon)
- self._add_linear(mask=None, type='upper', k=k, x0=l, y0=exp_l)
-
-
-class BoundLog(BoundActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
-
- @Bound.save_io_shape
- def forward(self, x):
- # NOTE adhoc implementation for loss fusion
- if self.loss_fusion:
- return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1)
- return torch.log(x.clamp(min=epsilon))
+ def interval_propagate(self, *v):
+ h_L, h_U = v[0][0], v[0][1]
+ lower = ((h_U < 0) * h_U.abs() + (h_L > 0) * h_L.abs())
+ upper = torch.max(h_L.abs(), h_U.abs())
+ return lower, upper
- def bound_relax(self, x):
- rl, ru = self.forward(x.lower), self.forward(x.upper)
- ku = (ru - rl) / (x.upper - x.lower + epsilon)
- self._add_linear(mask=None, type='lower', k=ku, x0=x.lower, y0=rl)
- m = (x.lower + x.upper) / 2
- k = torch.reciprocal(m)
- rm = self.forward(m)
- self._add_linear(mask=None, type='upper', k=k, x0=m, y0=rm)
- def interval_propagate(self, *v):
- # NOTE adhoc implementation for loss fusion
- if self.loss_fusion:
- par = self.inputs[0].inputs[0].inputs[0]
- if Interval.use_relative_bounds(*v):
- lower = torch.logsumexp(par.interval.nominal + par.interval.lower_offset, dim=-1)
- upper = torch.logsumexp(par.interval.nominal + par.interval.upper_offset, dim=-1)
- return Interval.make_interval(lower, upper, nominal=self.forward_value, use_relative=True)
- else:
- lower = torch.logsumexp(par.lower, dim=-1)
- upper = torch.logsumexp(par.upper, dim=-1)
- return lower, upper
- return super().interval_propagate(*v)
+class BoundATenHeaviside(BoundOptimizableActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.alpha_batch_dim = 2
- def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None):
- A, lbias, ubias = super().bound_backward(last_lA, last_uA, x)
- # NOTE adhoc implementation for loss fusion
- if self.loss_fusion:
- assert A[0][0] is None
- exp_module = self.inputs[0].inputs[0]
- ubias = ubias + self.get_bias(A[0][1], exp_module.max_input.squeeze(-1))
- return A, lbias, ubias
+ def forward(self, *x):
+ self.input_shape = x[0].shape
+ # x[0]: input; x[1]: value when x == 0
+ return torch.heaviside(x[0], x[1])
+ def init_opt_parameters(self, start_nodes):
+ l = self.inputs[0].lower
+ for ns, size_s, _ in start_nodes:
+ self.alpha[ns] = torch.zeros_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim).requires_grad_(True)
-class BoundPow(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
+ def clip_alpha_(self):
+ for v in self.alpha.values():
+ v.data = torch.clamp(v.data, 0., 1.)
- @Bound.save_io_shape
- def forward(self, x, y):
- return torch.pow(x, y)
+ def bound_backward(self, last_lA, last_uA, *x, start_node=None, start_shape=None):
+ x = x[0]
+ if x is not None:
+ lb_r = x.lower
+ ub_r = x.upper
+ else:
+ lb_r = self.lower
+ ub_r = self.upper
- def interval_propagate(self, *v):
- assert not self.is_input_perturbed(1)
-
- if Interval.use_relative_bounds(*v):
- exp = v[1].nominal
- assert exp == int(exp)
- exp = int(exp)
- h_L = v[0].nominal + v[0].lower_offset
- h_U = v[0].nominal + v[0].upper_offset
- lower, upper = torch.pow(h_L, exp), torch.pow(h_U, exp)
- if exp % 2 == 0:
- lower, upper = torch.min(lower, upper), torch.max(lower, upper)
- mask = 1 - ((h_L < 0) * (h_U > 0)).float()
- lower = lower * mask
- return Interval.make_interval(lower, upper, nominal=self.forward_value, use_relative=True)
-
- exp = v[1][0]
- assert exp == int(exp)
- exp = int(exp)
- pl, pu = torch.pow(v[0][0], exp), torch.pow(v[0][1], exp)
- if exp % 2 == 1:
- return pl, pu
+ if self.opt_stage not in ['opt', 'reuse']:
+ # zero slope:
+ upper_d = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)
+ lower_d = torch.zeros_like(ub_r, device=ub_r.device, dtype=ub_r.dtype)
else:
- pl, pu = torch.min(pl, pu), torch.max(pl, pu)
- mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).float()
- return pl * mask, pu
+ upper_d = self.alpha[start_node.name][0].clamp(0, 1) * (1. / (-lb_r).clamp(min=1e-3))
+ lower_d = self.alpha[start_node.name][1].clamp(0, 1) * (1. / (ub_r.clamp(min=1e-3)))
+
+ upper_b = torch.ones_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)
+ lower_b = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)
+ # For stable neurons, set fixed slope and bias.
+ ub_mask = (ub_r <= 0).to(dtype=ub_r.dtype)
+ lb_mask = (lb_r >= 0).to(dtype=lb_r.dtype)
+ upper_b = upper_b - upper_b * ub_mask
+ lower_b = lower_b * (1. - lb_mask) + lb_mask
+ upper_d = upper_d - upper_d * ub_mask - upper_d * lb_mask
+ lower_d = lower_d - lower_d * lb_mask - lower_d * ub_mask
+ upper_d = upper_d.unsqueeze(0)
+ lower_d = lower_d.unsqueeze(0)
+ # Choose upper or lower bounds based on the sign of last_A
+ uA = lA = None
+ ubias = lbias = 0
+ if last_uA is not None:
+ neg_uA = last_uA.clamp(max=0)
+ pos_uA = last_uA.clamp(min=0)
+ uA = upper_d * pos_uA + lower_d * neg_uA
+ ubias = (pos_uA * upper_b + neg_uA * lower_b).flatten(2).sum(-1)
+ if last_lA is not None:
+ neg_lA = last_lA.clamp(max=0)
+ pos_lA = last_lA.clamp(min=0)
+ lA = upper_d * neg_lA + lower_d * pos_lA
+ lbias = (pos_lA * lower_b + neg_lA * upper_b).flatten(2).sum(-1)
+
+ return [(lA, uA), (None, None)], lbias, ubias
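
The relaxation above chooses, per neuron, between the direct (chord) line and a tangent line whose tangent point is read from a pre-computed table (d_lower/d_upper). Below is a minimal standalone sketch of that choice for the lower bound of tanh on an interval crossing zero; it is not the library API, and it uses a hand-picked, deliberately conservative tangent point in place of the table lookup.

```python
import torch

def dtanh(x):
    # Derivative of tanh.
    return 1. - torch.tanh(x) ** 2

l, u = torch.tensor(-1.0), torch.tensor(2.0)
y_l, y_u = torch.tanh(l), torch.tanh(u)
# Slope of the direct line connecting (l, tanh(l)) and (u, tanh(u)).
k_direct = (y_u - y_l) / (u - l)
# Hand-picked conservative tangent point; the library instead looks up the
# optimal d_lower from a table pre-computed by binary search.
d_lower = torch.tensor(-1.5)
if k_direct < dtanh(l):
    # The direct line never rises above tanh on [l, u]: use it as the lower bound.
    k_lo, x0, y0 = k_direct, l, y_l
else:
    # Otherwise fall back to the tangent line at d_lower.
    k_lo, x0, y0 = dtanh(d_lower), d_lower, torch.tanh(d_lower)

xs = torch.linspace(float(l), float(u), 101)
# Soundness check: the chosen line stays below tanh on the whole interval.
assert torch.all(k_lo * (xs - x0) + y0 <= torch.tanh(xs) + 1e-6)
```
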
diff --git a/auto_LiRPA/operators/base.py b/auto_LiRPA/operators/base.py
index 2062116..b16d813 100644
--- a/auto_LiRPA/operators/base.py
+++ b/auto_LiRPA/operators/base.py
@@ -3,22 +3,27 @@
import os
import time
import math
+import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch import Tensor
import numpy as np
from itertools import chain
from numpy.lib.arraysetops import isin
from collections import OrderedDict
-from auto_LiRPA.perturbations import *
-from auto_LiRPA.utils import *
+from ..perturbations import *
+from ..utils import *
+from ..patches import *
+from ..linear_bound import LinearBound
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
epsilon = 1e-12
+
def not_implemented_op(node, func):
message = ("Function `{}` of `{}` is not supported yet."
" Please help to open an issue at https://github.com/KaidiXu/auto_LiRPA"
@@ -26,26 +31,21 @@ def not_implemented_op(node, func):
" or auto_LiRPA/operators by yourself.".format(func, node))
raise NotImplementedError(message)
-"""Interval object. Used for interval bound propagation."""
+
class Interval(tuple):
+ """Interval object for interval bound propagation."""
+
# Subclassing tuple object so that all previous code can be reused.
- def __new__(self, lb=None, ub=None, nominal=None, lower_offset=None, upper_offset=None, ptb=None):
+ def __new__(self, lb=None, ub=None, ptb=None):
return tuple.__new__(Interval, (lb, ub))
- def __init__(self, lb, ub, nominal=None, lower_offset=None, upper_offset=None, ptb=None):
- self.nominal = nominal
- self.lower_offset = lower_offset
- self.upper_offset = upper_offset
-
+ def __init__(self, lb, ub, ptb=None):
if ptb is None:
self.ptb = None
- # If relative bounds are not used, `self.ptb == None` means that this interval
+ # `self.ptb == None` means that this interval
# is not perturbed and it shall be treated as a constant and lb = ub.
- # But if relative bounds are used, every node in IBP is supposed to have an `Interval` object
- # even if this node is perturbed.
- if nominal is None:
- # To avoid mistakes, in this case the caller must make sure lb and ub are the same object.
- assert lb is ub
+ # To avoid mistakes, in this case the caller must make sure lb and ub are the same object.
+ assert lb is ub
else:
if not isinstance(ptb, Perturbation):
raise ValueError("ptb must be a Perturbation object or None. Got type {}".format(type(ptb)))
@@ -58,34 +58,17 @@ def __str__(self):
def __repr__(self):
return "Interval(lb={}, ub={}, ptb={})".format(self[0], self[1], self.ptb)
- @property
- def lower(self):
- return self.nominal + self.lower_offset
-
- @property
- def upper(self):
- return self.nominal + self.upper_offset
-
- """Checking if the other interval is tuple, keep the perturbation."""
-
@staticmethod
- def make_interval(lb, ub, other=None, nominal=None, use_relative=False):
+ def make_interval(lb, ub, other=None):
+ """Checking if the other interval is tuple, keep the perturbation."""
if isinstance(other, Interval):
return Interval(lb, ub, ptb=other.ptb)
else:
- if use_relative:
- if nominal is None:
- return Interval(
- None, None, (lb + ub) / 2, (lb - ub) / 2, (ub - lb) / 2)
- else:
- return Interval(None, None, nominal, lb - nominal, ub - nominal)
- else:
- return (lb, ub)
-
- """Given a tuple or Interval object, returns the norm and eps."""
+ return (lb, ub)
@staticmethod
def get_perturbation(interval):
+ """Given a tuple or Interval object, returns the norm and eps."""
if isinstance(interval, Interval) and interval.ptb is not None:
if isinstance(interval.ptb, PerturbationLpNorm):
return interval.ptb.norm, interval.ptb.eps
@@ -93,33 +76,21 @@ def get_perturbation(interval):
return np.inf, 1.0
elif isinstance(interval.ptb, PerturbationL0Norm):
return 0, interval.ptb.eps, interval.ptb.ratio
- # elif interval.ptb is None:
- # raise RuntimeError("get_perturbation() encountered an interval that is not perturbed.")
else:
raise RuntimeError("get_perturbation() does not know how to handle {}".format(type(interval.ptb)))
else:
# Tuple object. Assuming L infinity norm lower and upper bounds.
return np.inf, np.nan
- """Checking if a Interval or tuple object has perturbation enabled."""
@staticmethod
def is_perturbed(interval):
+ """Checking if a Interval or tuple object has perturbation enabled."""
if isinstance(interval, Interval) and interval.ptb is None:
return False
else:
return True
- @staticmethod
- def use_relative_bounds(*intervals):
- using = True
- for interval in intervals:
- using = using and (
- isinstance(interval, Interval) and
- interval.nominal is not None and
- interval.lower_offset is not None and interval.upper_offset is not None)
- return using
-
class Bound(nn.Module):
r"""
@@ -127,12 +98,6 @@ class Bound(nn.Module):
at `auto_LiRPA/operators`.
Args:
- input_name (list): The name of input nodes.
-
- name (str): The name of this node.
-
- ori_name (str): Name in the original model.
-
attr (dict): Attributes of the operator.
inputs (list): A list of input nodes.
@@ -141,18 +106,22 @@ class Bound(nn.Module):
options (dict): Bound options.
- device (str or torch.device): Device of the bounded module.
-
- Be sure to run `super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)`
+ Be sure to run `super().__init__(attr, inputs, output_index, options)`
first in the `__init__` function.
"""
- def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index=0, options={}, device=None):
+ def __init__(self, attr=None, inputs=None, output_index=0, options=None):
super().__init__()
+ attr = {} if attr is None else attr
+ inputs = [] if inputs is None else inputs
+ options = {} if options is None else options
+ self.name = None
self.output_name = []
- self.input_name, self.name, self.ori_name, self.attr, self.inputs, self.output_index, self.options, self.device = \
- input_name, name, ori_name, attr, inputs, output_index, options, device
+ self.device = attr.get('device')
+ self.attr, self.inputs, self.output_index, self.options = \
+ attr, inputs, output_index, options
self.forward_value = None
+ self.output_shape = None
self.from_input = False
self.bounded = False
self.IBP_rets = None
@@ -164,7 +133,7 @@ def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index=
self.loss_fusion = False
self.options = options
# Use `default_interval_propagate`
- self.use_default_ibp = False
+ self.use_default_ibp = False
# If set to true, the backward bound output of this node is 0.
self.zero_backward_coeffs_l = False
self.zero_backward_coeffs_u = False
@@ -172,15 +141,24 @@ def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index=
self.zero_lA_mtx = False
self.zero_uA_mtx = False
- """Check if the i-th input is with perturbation or not."""
+ self.patches_start = False
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(name="{self.name}")'
+
def is_input_perturbed(self, i=0):
- return self.inputs[i].perturbed
+ r"""Check if the i-th input is with perturbation or not."""
+ return i < len(self.inputs) and self.inputs[i].perturbed
+
+ def clear(self):
+ """ Clear attributes when there is a new input to the network"""
+ pass
def forward(self, *x):
r"""
- Function for standard/clean forward.
+ Function for standard/clean forward.
- Args:
+ Args:
x: A list of input values. The length of the list is equal to the number of input nodes.
Returns:
@@ -192,37 +170,30 @@ def interval_propagate(self, *v):
r"""
Function for interval bound propagation (IBP) computation.
- There is a default function `self.default_interval_propagate(*v)` in the base class,
+ There is a default function `self.default_interval_propagate(*v)` in the base class,
which can be used if the operator is *monotonic*. To use it, set `self.use_default_ibp = True`
in the `__init__` function, and the implementation of this function can be skipped.
- Args:
- v: A list of the interval bound of input nodes.
+ Args:
+ v: A list of the interval bound of input nodes.
Generally, for each element `v[i]`, `v[i][0]` is the lower interval bound,
and `v[i][1]` is the upper interval bound.
Returns:
bound: The interval bound of this node, in a same format as v[i].
- """
+ """
if self.use_default_ibp:
return self.default_interval_propagate(*v)
else:
return not_implemented_op(self, 'interval_propagate')
-
- """For unary monotonous functions or functions for altering shapes only but not values"""
+
def default_interval_propagate(self, *v):
+ """For unary monotonous functions or functions for altering shapes only but not values"""
if len(v) == 0:
return Interval.make_interval(self.forward(), self.forward())
elif len(v) == 1:
- if Interval.use_relative_bounds(v[0]):
- return Interval(
- None, None,
- self.forward(v[0].nominal),
- self.forward(v[0].lower_offset),
- self.forward(v[0].upper_offset)
- )
- else:
- return Interval.make_interval(self.forward(v[0][0]), self.forward(v[0][1]), v[0])
+ return Interval.make_interval(
+ self.forward(v[0][0]), self.forward(v[0][1]), v[0])
else:
raise NotImplementedError('default_interval_propagate only supports no more than 1 input node')
@@ -230,34 +201,41 @@ def default_interval_propagate(self, *v):
def bound_forward(self, dim_in, *x):
r"""
Function for forward mode bound propagation. Forward mode LiRPA computs a `LinearBound`
- instance representing the linear bound for each involved node. Major attributes of `LinearBound` include
+ instance representing the linear bound for each involved node.
+ Major attributes of `LinearBound` include
`lw`, `uw`, `lb`, `ub`, `lower`, and `upper`.
- `lw` and `uw` are coefficients of linear bounds w.r.t. model input.
- Their shape is `(batch_size, dim_in, *standard_shape)`, where `dim_in` is the total dimension
- of perturbed input nodes of the model, and `standard_shape` is the shape of the standard/clean output.
- `lb` and `ub` are bias terms of linear bounds, and their shape is equal to the shape of standard/clean output.
- `lower` and `upper` are concretized lower and upper bounds that will be computed later in BoundedModule.
+ `lw` and `uw` are coefficients of linear bounds w.r.t. model input.
+ Their shape is `(batch_size, dim_in, *standard_shape)`,
+ where `dim_in` is the total dimension of perturbed input nodes of the model,
+ and `standard_shape` is the shape of the standard/clean output.
+ `lb` and `ub` are bias terms of linear bounds, and their shape is equal
+ to the shape of standard/clean output.
+ `lower` and `upper` are concretized lower and upper bounds that will be
+ computed later in BoundedModule.
- Args:
+ Args:
dim_in (int): Total dimension of perturbed input nodes of the model.
-
+
x: A list of the linear bound of input nodes. Each element in x is a `LinearBound` instance.
Returns:
bound (LinearBound): The linear bound of this node.
- """
+ """
return not_implemented_op(self, 'bound_forward')
+ def bound_dynamic_forward(self, *x, max_dim=None, offset=0):
+ raise NotImplementedError(f'bound_dynamic_forward is not implemented for {self}.')
+
def bound_backward(self, last_lA, last_uA, *x):
r"""
Function for backward mode bound propagation.
- Args:
+ Args:
last_lA (Tensor): `A` matrix for lower bound computation propagated to this node. It can be `None` if lower bound is not needed.
-
+
last_uA (Tensor): `A` matrix for upper bound computation propagated to this node. It can be `None` if upper bound is not needed.
-
+
x: A list of input nodes, with x[i].lower and x[i].upper that can be used as pre-activation bounds.
Returns:
@@ -265,8 +243,8 @@ def bound_backward(self, last_lA, last_uA, *x):
lbias (Tensor): The bias term for lower bound computation, introduced by the linear relaxation of this node.
- ubias (Tensor): The bias term for upper bound computation, introduced by the linear relaxation of this node.
- """
+ ubias (Tensor): The bias term for upper bound computation, introduced by the linear relaxation of this node.
+ """
return not_implemented_op(self, 'bound_backward')
def infer_batch_dim(self, batch_size, *x):
@@ -278,8 +256,8 @@ def infer_batch_dim(self, batch_size, *x):
def broadcast_backward(self, A, x):
shape = x.output_shape
batch_dim = max(self.batch_dim, 0)
-
- if isinstance(A, torch.Tensor):
+
+ if isinstance(A, Tensor):
if x.batch_dim == -1:
# final shape of input
shape = torch.Size([A.shape[batch_dim + 1]] + list(shape))
@@ -298,6 +276,7 @@ def broadcast_backward(self, A, x):
dims = []
for i in range(len(shape)):
# Skip the batch dimension.
+ # FIXME (05/11/2022): the following condition is not always correct. We should not rely on checking whether a dimension equals 1.
if shape[i] == 1 and A.shape[i + 1] != 1 and i != batch_dim:
dims.append(i + 1)
if dims:
@@ -307,30 +286,6 @@ def broadcast_backward(self, A, x):
pass
return A
- @staticmethod
- def broadcast_forward(dim_in, x, shape_res):
- lw, lb, uw, ub = x.lw, x.lb, x.uw, x.ub
- shape_x, shape_res = list(x.lb.shape), list(shape_res)
- if lw is None:
- lw = uw = torch.zeros(dim_in, *shape_x, device=lb.device)
- has_batch_size = False
- else:
- has_batch_size = True
- while len(shape_x) < len(shape_res):
- if not has_batch_size:
- lw, uw = lw.unsqueeze(0), uw.unsqueeze(0)
- lb, ub = lb.unsqueeze(0), ub.unsqueeze(0)
- shape_x = [1] + shape_x
- has_batch_size = True
- else:
- lw, uw = lw.unsqueeze(2), uw.unsqueeze(2)
- lb, ub = lb.unsqueeze(1), ub.unsqueeze(1)
- shape_x = [shape_x[0], 1] + shape_x[1:]
- lb, ub = lb.expand(*shape_res), ub.expand(*shape_res)
- lw = lw.expand(shape_res[0], lw.size(1), *shape_res[1:])
- uw = uw.expand(shape_res[0], uw.size(1), *shape_res[1:])
- return lw, lb, uw, ub
-
def get_bias(self, A, bias):
if A is None:
return 0
@@ -340,8 +295,7 @@ def get_bias(self, A, bias):
if torch.isinf(bias).any():
warnings.warn('There is an inf value in the bias of LiRPA bounds.')
- if isinstance(A, torch.Tensor):
- output_dim = A.shape[0]
+ if isinstance(A, Tensor):
if self.batch_dim != -1:
bias_new = torch.einsum('sb...,b...->sb', A, bias)
else:
@@ -354,14 +308,19 @@ def get_bias(self, A, bias):
else:
# FIXME (09/17): handle the case for pieces.unstable_idx.
return bias_new
+ elif isinstance(A, eyeC):
+ batch_size = A.shape[1]
+ if self.batch_dim != -1:
+ return bias.reshape(batch_size, -1).t()
+ else:
+ return bias.reshape(-1).unsqueeze(-1).repeat(1, batch_size)
elif type(A) == Patches:
# the shape of A.patches is [batch, L, out_c, in_c, K, K]
if self.batch_dim != -1:
# Input A patches has shape (spec, batch, out_h, out_w, in_c, H, W) or (unstable_size, batch, in_c, H, W).
patches = A.patches
-
# Here the size of bias is [batch_size, out_h, out_w, in_c, H, W]
- bias = inplace_unfold(bias, kernel_size=A.patches.shape[-2:], stride=A.stride, padding=A.padding)
+ bias = inplace_unfold(bias, kernel_size=A.patches.shape[-2:], stride=A.stride, padding=A.padding, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding)
if A.unstable_idx is not None:
# Sparse bias has shape [unstable_size, batch_size, in_c, H, W]. No need to select over the out_c dimension.
bias = bias[:, A.unstable_idx[1], A.unstable_idx[2]]
@@ -379,42 +338,24 @@ def get_bias(self, A, bias):
bias_new = torch.sum(patches, dim=(-1, -2, -3)) * bias.to(self.device)
# Return shape is (spec, batch, out_h, out_w) or (unstable_size, batch).
return bias_new
-
return bias_new
else:
return NotImplementedError()
- @staticmethod
- @torch.jit.script
- def clamp_mutiply(A, pos, neg):
- Apos = A.clamp(min=0)
- Aneg = A.clamp(max=0)
- return pos.contiguous() * Apos + neg.contiguous() * Aneg, Apos, Aneg
-
- @staticmethod
- @torch.jit.script
- def clamp_mutiply_non_contiguous(A, pos, neg):
- Apos = A.clamp(min=0)
- Aneg = A.clamp(max=0)
- return pos * Apos + neg * Aneg, Apos, Aneg
-
- """save input and output shapes uniformly by the decorator"""
- @staticmethod
- def save_io_shape(func):
- def wrapper(self, *args, **kwargs):
- if len(args) > 0:
- self.input_shape = args[0].shape # x should always be the first input
-
- output = func(self, *args, **kwargs)
-
- if isinstance(output, torch.Tensor):
- self.output_shape = output.shape
- return output
-
- return wrapper
+ def make_axis_non_negative(self, axis, shape='input'):
+ if shape == 'input':
+ shape = self.input_shape
+ elif shape == 'output':
+ shape = self.output_shape
+ else:
+ assert isinstance(shape, torch.Size)
+ if axis < 0:
+ return axis + len(shape)
+ else:
+ return axis
- """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it."""
def non_deter_wrapper(self, op, *args, **kwargs):
+ """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it."""
if self.options.get('deterministic', False):
torch.use_deterministic_algorithms(False)
ret = op(*args, **kwargs)
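
With the relative-bounds mode removed, `Interval` carries only `(lb, ub)` plus an optional perturbation. A short usage sketch of the simplified semantics follows, assuming this revision of auto_LiRPA is importable (import path taken from the file above).

```python
import torch
from auto_LiRPA.operators.base import Interval

lb = torch.tensor([-1.0, 0.0])
ub = torch.tensor([1.0, 2.0])

# Without an `other` Interval to copy a perturbation from, make_interval
# returns a plain tuple rather than an Interval.
iv = Interval.make_interval(lb, ub)
assert isinstance(iv, tuple) and not isinstance(iv, Interval)

# A constant (unperturbed) Interval must pass the very same tensor object
# as both lb and ub; the constructor asserts `lb is ub`.
c = torch.tensor([3.0])
const_iv = Interval(c, c)
assert const_iv[0] is const_iv[1]
assert not Interval.is_perturbed(const_iv)
```
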
diff --git a/auto_LiRPA/operators/bivariate.py b/auto_LiRPA/operators/bivariate.py
index 8ec0936..ffda86e 100644
--- a/auto_LiRPA/operators/bivariate.py
+++ b/auto_LiRPA/operators/bivariate.py
@@ -1,14 +1,33 @@
""" Bivariate operators"""
from .base import *
-from .activation import BoundSqrt, BoundReciprocal
+from .nonlinear import BoundSqrt, BoundReciprocal
+from .clampmult import multiply_by_A_signs
+from ..utils import *
+from .solver_utils import grb
+from .constant import BoundConstant
+from .leaf import BoundParams, BoundBuffers
class BoundMul(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.is_constant_op = False
+ for inp in inputs:
+ if BoundMul._check_const_input(inp):
+ # If any of the two inputs are constant, we do not need input bounds.
+ # FIXME (05/11/2022): this is just a temporary workaround. We need a better way to determine whether we need input bounds, not just for BoundConstant.
+ self.is_constant_op = True
+ if self.is_constant_op:
+ # One input is constant; no bounds required.
+ self.requires_input_bounds = []
+ else:
+ # Both inputs are perturbed. Need relaxation.
+ self.requires_input_bounds = [0, 1]
+
+ @staticmethod
+ def _check_const_input(inp):
+ return isinstance(inp, (BoundConstant, BoundBuffers)) or (isinstance(inp, BoundParams) and inp.perturbation is None)
- @Bound.save_io_shape
def forward(self, x, y):
self.x_shape = x.shape
self.y_shape = y.shape
@@ -23,7 +42,6 @@ def get_bound_mul(x_l, x_u, y_l, y_u):
alpha_u = y_u
beta_u = x_l
gamma_u = -alpha_u * beta_u
-
return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u
# Special case when input is x * x.
@@ -58,17 +76,62 @@ def _relax(x, y):
x_l, x_u = x.lower, x.upper
y_l, y_u = y.lower, y.upper
- # broadcast
- for k in [1, -1]:
- x_l = x_l + k * y_l
- x_u = x_u + k * y_u
- for k in [1, -1]:
- y_l = y_l + k * x_l
- y_u = y_u + k * x_u
+ # Broadcast
+ x_l = x_l + torch.zeros_like(y_l)
+ x_u = x_u + torch.zeros_like(y_u)
+ y_l = y_l + torch.zeros_like(x_l)
+ y_u = y_u + torch.zeros_like(x_u)
return BoundMul.get_bound_mul(x_l, x_u, y_l, y_u)
+ @staticmethod
+ def _multiply_by_const(x, const):
+ if isinstance(x, torch.Tensor):
+ return x * const
+ elif isinstance(x, Patches):
+ # Multiplies patches by a const. Assuming const is a tensor, and it must be in nchw format.
+ assert isinstance(const, torch.Tensor) and const.ndim == 4
+ if const.size(0) == x.patches.size(1) and const.size(1) == x.patches.size(-3) and const.size(2) == const.size(3) == 1:
+ # The case that we can do channel-wise broadcasting multiplication
+ # Shape of const: (batch, in_c, 1, 1)
+ # Shape of patches when unstable_idx is None: (spec, batch, in_c, patch_h, patch_w)
+ # Shape of patches when unstable_idx is not None: (out_c, batch, out_h, out_w, in_c, patch_h, patch_w)
+ const_reshaped = const
+ else:
+ assert x.unstable_idx is None and (x.padding == 0 or x.padding == [0,0,0,0]) and x.stride == 1 and x.patches.size(-1) == x.patches.size(-2) == 1
+ # The assumed dimension is (out_c, N, out_h, out_w, in_c, 1, 1) with padding = 0 and stride = 1.
+ # In this special case we can directly multiply.
+ # After reshape it is (1, N, H, W, C, 1, 1)
+ const_reshaped = const.permute(0, 2, 3, 1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ return x.create_similar(x.patches * const_reshaped)
+ else:
+ raise ValueError(f'Unsupported x type {type(x)}')
+
+ @staticmethod
+ def bound_backward_constant(last_lA, last_uA, x, y, op=None):
+ op = BoundMul._multiply_by_const if op is None else op
+ # Handle the case of multiplication by a constant.
+ factor = None
+ if not BoundMul._check_const_input(x):
+ factor = y.value
+ if not BoundMul._check_const_input(y):
+ factor = x.value
+ # No need to compute A matrix if it is Constant.
+ lAx = None if BoundMul._check_const_input(x) or last_lA is None else op(last_lA, factor)
+ lAy = None if BoundMul._check_const_input(y) or last_lA is None else op(last_lA, factor)
+ uAx = None if BoundMul._check_const_input(x) or last_uA is None else op(last_uA, factor)
+ uAy = None if BoundMul._check_const_input(y) or last_uA is None else op(last_uA, factor)
+
+ return [(lAx, uAx), (lAy, uAy)], 0., 0.
+
+
def bound_backward(self, last_lA, last_uA, x, y):
+ if self.is_constant_op:
+ return self.bound_backward_constant(last_lA, last_uA, x, y)
+ else:
+ return self.bound_backward_both_perturbed(last_lA, last_uA, x, y)
+
+ def bound_backward_both_perturbed(self, last_lA, last_uA, x, y):
alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = BoundMul._relax(x, y)
alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0)
@@ -79,14 +142,78 @@ def _bound_oneside(last_A,
alpha_neg, beta_neg, gamma_neg):
if last_A is None:
return None, None, 0
- last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0)
- A_x = last_A_pos * alpha_pos + last_A_neg * alpha_neg
- A_y = last_A_pos * beta_pos + last_A_neg * beta_neg
- last_A = last_A.reshape(last_A.shape[0], last_A.shape[1], -1)
- A_x = self.broadcast_backward(A_x, x)
- A_y = self.broadcast_backward(A_y, y)
- bias = self.get_bias(last_A_pos, gamma_pos) + \
- self.get_bias(last_A_neg, gamma_neg)
+
+ if type(last_A) == Patches:
+ # In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return.
+ def _maybe_unfold(d_tensor, last_A):
+ if d_tensor is None:
+ return None
+
+ d_shape = d_tensor.size()
+ # Reshape to 4-D tensor to unfold.
+ d_tensor = d_tensor.view(-1, *d_shape[-3:])
+ # Unfold the slope matrix as patches. Patch shape is (spec * batch, out_h, out_w, in_c, H, W).
+ d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding)
+ # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c.
+ d_unfolded_r = d_unfolded.view(*last_A.shape[:3], *d_unfolded.shape[1:])
+ if last_A.unstable_idx is not None:
+ if d_unfolded_r.size(0) == 1:
+ # Broadcast the spec shape, so we only need to select the rest of the dimensions.
+ # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).
+ d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ # output shape: (unstable_size, batch, in_c, H, W).
+ else:
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).
+ # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W).
+ return d_unfolded_r
+ # if last_A is not an identity matrix
+ assert last_A.identity == 0
+ if last_A.identity == 0:
+ # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.
+ # for patches mode, we need to unfold the alpha_pos/neg and beta_pos/neg
+
+ alpha_pos = _maybe_unfold(alpha_pos, last_A)
+ alpha_neg = _maybe_unfold(alpha_neg, last_A)
+ beta_pos = _maybe_unfold(beta_pos, last_A)
+ beta_neg = _maybe_unfold(beta_neg, last_A)
+
+ gamma_pos = _maybe_unfold(gamma_pos, last_A)
+ gamma_neg = _maybe_unfold(gamma_neg, last_A)
+
+ patches = last_A.patches
+ patches_shape = patches.shape
+ A_x, bias = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), alpha_pos, alpha_neg, gamma_pos, gamma_neg, patches_mode=True)
+ A_y, _ = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), beta_pos, beta_neg, None, None, patches_mode=True)
+ A_x = A_x.view(patches_shape)
+ A_y = A_y.view(patches_shape)
+
+ # broadcast_backward
+ x_dims = []
+ y_dims = []
+
+ if A_x.shape[A_x.ndim-4] != x.output_shape[len(x.output_shape)-4]:
+ x_dims.append(A_x.ndim-4)
+
+ if A_y.shape[A_y.ndim-4] != y.output_shape[len(y.output_shape)-4]:
+ y_dims.append(A_y.ndim-4)
+
+ if len(x_dims) > 0:
+ A_x = A_x.sum(tuple(x_dims), keepdim=True)
+ if len(y_dims) > 0:
+ A_y = A_y.sum(tuple(y_dims), keepdim=True)
+
+ A_x = Patches(A_x, last_A.stride, last_A.padding, A_x.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape)
+ A_y = Patches(A_y, last_A.stride, last_A.padding, A_y.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape)
+ if type(last_A) == Tensor:
+ last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0)
+ A_x = last_A_pos * alpha_pos + last_A_neg * alpha_neg
+ A_y = last_A_pos * beta_pos + last_A_neg * beta_neg
+ A_x = self.broadcast_backward(A_x, x)
+ A_y = self.broadcast_backward(A_y, y)
+ bias = self.get_bias(last_A_pos, gamma_pos) + \
+ self.get_bias(last_A_neg, gamma_neg)
return A_x, A_y, bias
lA_x, lA_y, lbias = _bound_oneside(
@@ -96,8 +223,13 @@ def _bound_oneside(last_A,
return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias
+ def bound_forward(self, dim_in, x, y):
+ if self.is_constant_op:
+ raise NotImplementedError
+ return self.bound_forward_both_perturbed(dim_in, x, y)
+
@staticmethod
- def bound_forward(dim_in, x, y):
+ def bound_forward_both_perturbed(dim_in, x, y):
x_lw, x_lb, x_uw, x_ub = x.lw, x.lb, x.uw, x.ub
y_lw, y_lb, y_uw, y_ub = y.lw, y.lb, y.uw, y.ub
@@ -120,7 +252,28 @@ def bound_forward(dim_in, x, y):
return LinearBound(lw, lb, uw, ub)
@staticmethod
- def interval_propagate(*v):
+ def interval_propagate_constant(*v, op=lambda x, const: x * const):
+ x, y = v[0], v[1]
+ x_is_const = x[0] is x[1] # FIXME: use a better way to represent constant perturbation.
+ y_is_const = y[0] is y[1] # We should not check the distance between x[0] and x[1]. It's slow!
+ assert x_is_const or y_is_const
+ const = x[0] if x_is_const else y[0]
+ inp_lb = x[0] if y_is_const else y[0]
+ inp_ub = x[1] if y_is_const else y[1]
+ pos_mask = (const > 0).to(dtype=inp_lb.dtype)
+ neg_mask = 1. - pos_mask
+ lb = op(inp_lb, const * pos_mask) + op(inp_ub, const * neg_mask)
+ ub = op(inp_ub, const * pos_mask) + op(inp_lb, const * neg_mask)
+ return lb, ub
+
+ def interval_propagate(self, *v):
+ if self.is_constant_op:
+ return self.interval_propagate_constant(*v)
+ else:
+ return self.interval_propagate_both_perturbed(*v)
+
+ @staticmethod
+ def interval_propagate_both_perturbed(*v):
x, y = v[0], v[1]
if x is y:
# A shortcut for x * x.
@@ -133,27 +286,16 @@ def interval_propagate(*v):
l = F.relu(h_L) - F.relu(-h_U)
return l * l, torch.max(r0, r1)
- if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y):
- nominal = x.nominal * y.nominal
- lower_offset = (
- x.nominal.clamp(min=0) * (y.lower_offset) +
- x.nominal.clamp(max=0) * (y.upper_offset) +
- y.nominal.clamp(min=0) * (x.lower_offset) +
- y.nominal.clamp(max=0) * (x.upper_offset) +
- torch.min(x.lower_offset * y.upper_offset, x.upper_offset * y.lower_offset))
- upper_offset = (
- x.nominal.clamp(min=0) * (y.upper_offset) +
- x.nominal.clamp(max=0) * (y.lower_offset) +
- y.nominal.clamp(min=0) * (x.upper_offset) +
- y.nominal.clamp(max=0) * (x.lower_offset) +
- torch.max(x.lower_offset * y.lower_offset, x.upper_offset * y.upper_offset))
- return Interval(None, None, nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset)
-
r0, r1, r2, r3 = x[0] * y[0], x[0] * y[1], x[1] * y[0], x[1] * y[1]
lower = torch.min(torch.min(r0, r1), torch.min(r2, r3))
upper = torch.max(torch.max(r0, r1), torch.max(r2, r3))
return lower, upper
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ for vi in v:
+ assert isinstance(vi, Tensor), "build solver for BoundMul only with tensors for now"
+ self.solver_vars = v[0] * v[1]
+
@staticmethod
def infer_batch_dim(batch_size, *x):
if x[0] == -1:
@@ -166,13 +308,24 @@ def infer_batch_dim(batch_size, *x):
class BoundDiv(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.nonlinear = True
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.is_constant_op = False
+ for inp in inputs:
+ if isinstance(inp, (BoundConstant, BoundBuffers)):
+ # If any of the two inputs are constant, we do not need input bounds.
+ # FIXME (05/11/2022): this is just a temporary workaround. We need a better way to determine whether we need input bounds, not just for BoundConstant.
+ # FIXME: unify this handling with BoundMul.
+ self.is_constant_op = True
+ if self.is_constant_op:
+ # One input is constant; no bounds required.
+ self.requires_input_bounds = []
+ else:
+ # Both inputs are perturbed. Need relaxation.
+ self.requires_input_bounds = [0, 1]
- @Bound.save_io_shape
def forward(self, x, y):
- # ad-hoc implementation for layer normalization
+ # FIXME (05/11/2022): ad-hoc implementation for layer normalization
if isinstance(self.inputs[1], BoundSqrt):
input = self.inputs[0].inputs[0]
x = input.forward_value
@@ -188,21 +341,37 @@ def forward(self, x, y):
return x / y
def bound_backward(self, last_lA, last_uA, x, y):
+ if self.is_constant_op:
+ return BoundMul.bound_backward_constant(last_lA, last_uA, x, y, op=lambda x, const: BoundMul._multiply_by_const(x, 1/const))
+ else:
+ return self.bound_backward_both_perturbed(last_lA, last_uA, x, y)
+
+ def bound_backward_both_perturbed(self, last_lA, last_uA, x, y):
reciprocal, mul, y_r = self._convert_to_mul(x, y)
A, lower_b, upper_b = mul.bound_backward(last_lA, last_uA, x, y_r)
-
A_y, lower_b_y, upper_b_y = reciprocal.bound_backward(A[1][0], A[1][1], y)
+ if isinstance(upper_b_y, Tensor) and upper_b_y.ndim == 1:
+ upper_b_y = upper_b_y.unsqueeze(-1)
+ if isinstance(lower_b_y, Tensor) and lower_b_y.ndim == 1:
+ lower_b_y = lower_b_y.unsqueeze(-1)
upper_b = upper_b + upper_b_y
lower_b = lower_b + lower_b_y
-
return [A[0], A_y[0]], lower_b, upper_b
def bound_forward(self, dim_in, x, y):
+ assert not self.is_constant_op
reciprocal, mul, y_r = self._convert_to_mul(x, y)
y_r_linear = reciprocal.bound_forward(dim_in, y)
- y_r_linear = y_r_linear._replace(lower=y_r.lower, upper=y_r.upper)
+ y_r_linear.lower = y_r.lower
+ y_r_linear.upper = y_r.upper
return mul.bound_forward(dim_in, x, y_r_linear)
+ def interval_propagate(self, *v):
+ if self.is_constant_op:
+ return BoundMul.interval_propagate_constant(*v, op=lambda x, const: x / const)
+ else:
+ return self.interval_propagate_both_perturbed(*v)
+
- def interval_propagate(self, *v):
+ def interval_propagate_both_perturbed(self, *v):
# ad-hoc implementation for layer normalization
"""
@@ -215,28 +384,28 @@ def interval_propagate(self, *v):
1 / ( sqrt (1/n * sum_j Upper{(x_j-mu)^2/(x_i-mu)^2} ))
Lower{(x_j-mu)^2/(x_i-mu)^2}
- Lower{sum_j (x_j-mu)^2} / Upper{(x_i-mu)^2}
+ Lower{sum_j (x_j-mu)^2} / Upper{(x_i-mu)^2}
Upper{(x_j-mu)^2/(x_i-mu)^2}
- Upper{sum_j (x_j-mu)^2} / Lower{(x_i-mu)^2}
- """
+ Upper{sum_j (x_j-mu)^2} / Lower{(x_i-mu)^2}
+ """
if isinstance(self.inputs[1], BoundSqrt):
input = self.inputs[0].inputs[0]
n = input.forward_value.shape[-1]
-
+
h_L, h_U = input.lower, input.upper
dev_lower = (
- h_L * (1 - 1. / n) -
+ h_L * (1 - 1. / n) -
(h_U.sum(dim=-1, keepdim=True) - h_U) / n
)
dev_upper = (
- h_U * (1 - 1. / n) -
+ h_U * (1 - 1. / n) -
(h_L.sum(dim=-1, keepdim=True) - h_L) / n
)
- dev_sqr_lower = (1 - (dev_lower < 0).float() * (dev_upper > 0).float()) * \
- torch.min(dev_lower.abs(), dev_upper.abs())**2
+ dev_sqr_lower = (1 - (dev_lower < 0).to(dev_lower.dtype) * (dev_upper > 0).to(dev_lower.dtype)) * \
+ torch.min(dev_lower.abs(), dev_upper.abs())**2
dev_sqr_upper = torch.max(dev_lower.abs(), dev_upper.abs())**2
sum_lower = (dev_sqr_lower.sum(dim=-1, keepdim=True) - dev_sqr_lower) / dev_sqr_upper.clamp(min=epsilon)
@@ -245,8 +414,8 @@ def interval_propagate(self, *v):
dev_sqr_lower.clamp(min=epsilon)
sqrt_upper = torch.sqrt(1. / n * (sum_upper + 1))
- lower = (dev_lower < 0).float() * (-1. / sqrt_lower) + (dev_lower > 0).float() * (1. / sqrt_upper)
- upper = (dev_upper > 0).float() * (1. / sqrt_lower) + (dev_upper < 0).float() * (-1. / sqrt_upper)
+ lower = (dev_lower < 0).to(dev_lower.dtype) * (-1. / sqrt_lower) + (dev_lower > 0).to(dev_lower.dtype) * (1. / sqrt_upper)
+ upper = (dev_upper > 0).to(dev_upper.dtype) * (1. / sqrt_lower) + (dev_upper < 0).to(dev_upper.dtype) * (-1. / sqrt_upper)
return lower, upper
@@ -256,33 +425,35 @@ def interval_propagate(self, *v):
def _convert_to_mul(self, x, y):
try:
- reciprocal = BoundReciprocal(self.input_name, self.name + '/reciprocal', self.ori_name, {}, [], 0, None,
- self.device)
- mul = BoundMul(self.input_name, self.name + '/mul', self.ori_name, {}, [], 0, None, self.device)
+ reciprocal = BoundReciprocal({}, [], 0, None)
+ mul = BoundMul({}, [], 0, None)
except:
# to make it compatible with previous code
- reciprocal = BoundReciprocal(self.input_name, self.name + '/reciprocal', None, {}, [], 0, None, self.device)
- mul = BoundMul(self.input_name, self.name + '/mul', None, {}, [], 0, None, self.device)
+ reciprocal = BoundReciprocal(None, {}, [], 0, None)
+ mul = BoundMul(None, {}, [], 0, None)
reciprocal.output_shape = mul.output_shape = self.output_shape
reciprocal.batch_dim = mul.batch_dim = self.batch_dim
y_r = copy.copy(y)
if isinstance(y_r, LinearBound):
- y_r = y_r._replace(lower=1. / y.upper, upper=1. / y.lower)
+ y_r.lower = 1. / y.upper
+ y_r.upper = 1. / y.lower
else:
y_r.lower = 1. / y.upper
y_r.upper = 1. / y.lower
return reciprocal, mul, y_r
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ for vi in v:
+ assert isinstance(vi, Tensor), "build solver for BoundDiv only with tensors for now"
+ self.solver_vars = v[0] / v[1]
class BoundAdd(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used.
self.mode = options.get("conv_mode", "matrix")
- @Bound.save_io_shape
def forward(self, x, y):
self.x_shape = x.shape
self.y_shape = y.shape
@@ -301,33 +472,59 @@ def _bound_oneside(last_A, w):
return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0
def bound_forward(self, dim_in, x, y):
- x_lw, x_lb, x_uw, x_ub = Bound.broadcast_forward(dim_in, x, self.output_shape)
- y_lw, y_lb, y_uw, y_ub = Bound.broadcast_forward(dim_in, y, self.output_shape)
- lw, lb = x_lw + y_lw, x_lb + y_lb
- uw, ub = x_uw + y_uw, x_ub + y_ub
- return LinearBound(lw, lb, uw, ub)
+ lb, ub = x.lb + y.lb, x.ub + y.ub
- def interval_propagate(self, x, y):
- assert (not isinstance(y, torch.Tensor))
+ def add_w(x_w, y_w, x_b, y_b):
+ if x_w is None and y_w is None:
+ return None
+ elif x_w is not None and y_w is not None:
+ return x_w + y_w
+ elif y_w is None:
+ return x_w + torch.zeros_like(y_b)
+ else:
+ return y_w + torch.zeros_like(x_b)
- if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y):
- return Interval(
- None, None,
- x.nominal + y.nominal,
- x.lower_offset + y.lower_offset,
- x.upper_offset + y.upper_offset)
+ lw = add_w(x.lw, y.lw, x.lb, y.lb)
+ uw = add_w(x.uw, y.uw, x.ub, y.ub)
- return x[0] + y[0], x[1] + y[1]
+ return LinearBound(lw, lb, uw, ub)
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x)
+ def interval_propagate(self, x, y):
+ assert (not isinstance(y, Tensor))
+ return x[0] + y[0], x[1] + y[1]
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ if isinstance(v[0], Tensor) and isinstance(v[1], Tensor):
+ # constants if both inputs are tensors
+ self.solver_vars = self.forward(v[0], v[1])
+ return
+ # we have both gurobi vars as inputs
+ this_layer_shape = self.output_shape
+ gvar_array1 = np.array(v[0])
+ gvar_array2 = np.array(v[1])
+ assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:]
+
+ # flatten to create vars and constrs first
+ gvar_array1 = gvar_array1.reshape(-1)
+ gvar_array2 = gvar_array2.reshape(-1)
+ new_layer_gurobi_vars = []
+ for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)):
+ var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq')
+ new_layer_gurobi_vars.append(var)
+
+ # reshape to the correct list shape of solver vars
+ self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist()
+ model.update()
class BoundSub(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used.
+ self.mode = options.get("conv_mode", "matrix")
- @Bound.save_io_shape
def forward(self, x, y):
self.x_shape = x.shape
self.y_shape = y.shape
@@ -337,7 +534,17 @@ def bound_backward(self, last_lA, last_uA, x, y):
def _bound_oneside(last_A, w, sign=-1):
if last_A is None:
return None
- return self.broadcast_backward(sign * last_A, w)
+ if isinstance(last_A, torch.Tensor):
+ return self.broadcast_backward(sign * last_A, w)
+ elif isinstance(last_A, Patches):
+ if sign == 1:
+ # Patches shape requires no broadcast.
+ return last_A
+ else:
+ # Multiply by the sign.
+ return last_A.create_similar(sign * last_A.patches)
+ else:
+ raise ValueError(f'Unknown last_A type {type(last_A)}')
uA_x = _bound_oneside(last_uA, x, sign=1)
uA_y = _bound_oneside(last_uA, y, sign=-1)
@@ -346,32 +553,55 @@ def _bound_oneside(last_A, w, sign=-1):
return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0
def bound_forward(self, dim_in, x, y):
- x_lw, x_lb, x_uw, x_ub = Bound.broadcast_forward(dim_in, x, self.output_shape)
- y_lw, y_lb, y_uw, y_ub = Bound.broadcast_forward(dim_in, y, self.output_shape)
- lw, lb = x_lw - y_uw, x_lb - y_ub
- uw, ub = x_uw - y_lw, x_ub - y_lb
+ lb, ub = x.lb - y.ub, x.ub - y.lb
+
+ def add_w(x_w, y_w, x_b, y_b):
+ if x_w is None and y_w is None:
+ return None
+ elif x_w is not None and y_w is not None:
+ return x_w + y_w
+ elif y_w is None:
+ return x_w + torch.zeros_like(y_b)
+ else:
+ return y_w + torch.zeros_like(x_b)
+
+ lw = add_w(x.lw, -y.uw if y.uw is not None else None, x.lb, y.lb)
+ uw = add_w(x.uw, -y.lw if y.lw is not None else None, x.ub, y.ub)
+
return LinearBound(lw, lb, uw, ub)
def interval_propagate(self, x, y):
- if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y):
- return Interval(
- None, None,
- x.nominal - y.nominal,
- x.lower_offset - y.upper_offset,
- x.upper_offset - y.lower_offset)
-
return x[0] - y[1], x[1] - y[0]
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ if isinstance(v[0], Tensor) and isinstance(v[1], Tensor):
+ # constants if both inputs are tensors
+ self.solver_vars = self.forward(v[0], v[1])
+ return
+ # we have both gurobi vars as inputs
+ this_layer_shape = self.output_shape
+ gvar_array1 = np.array(v[0])
+ gvar_array2 = np.array(v[1])
+ assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:]
+
+ # flatten to create vars and constrs first
+ gvar_array1 = gvar_array1.reshape(-1)
+ gvar_array2 = gvar_array2.reshape(-1)
+ new_layer_gurobi_vars = []
+ for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)):
+ var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ model.addConstr(var == (var1 - var2), name=f'lay{self.name}_{neuron_idx}_eq')
+ new_layer_gurobi_vars.append(var)
+
+ # reshape to the correct list shape of solver vars
+ self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist()
+ model.update()
class BoundEqual(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, x, y):
return x == y
-
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x)
\ No newline at end of file
diff --git a/auto_LiRPA/operators/clampmult.py b/auto_LiRPA/operators/clampmult.py
new file mode 100644
index 0000000..7241fb6
--- /dev/null
+++ b/auto_LiRPA/operators/clampmult.py
@@ -0,0 +1,238 @@
+"""Element multiplication with the A matrix based on its sign."""
+import torch
+import time
+from typing import Optional, Tuple
+from torch import Tensor
+from ..patches import Patches
+
+
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+
+
+# @torch.jit.script
+def _reference_multiply_by_A_signs(A: Tensor, d_pos: Tensor, d_neg: Tensor,
+ b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]:
+ """Reference implementation."""
+ A_pos = A.clamp(min=0)
+ A_neg = A.clamp(max=0)
+ A_new = d_pos * A_pos + d_neg * A_neg
+ bias_pos = bias_neg = torch.tensor(0.)
+ if b_pos is not None:
+ if patches_mode:
+ bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos)
+ else:
+ bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos)
+ if b_neg is not None:
+ if patches_mode:
+ bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg)
+ else:
+ bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg)
+ return A_new, bias_pos + bias_neg
+
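+# Sanity-check sketch (arbitrary shapes, not part of the library): the reference
+# implementation is equivalent to an elementwise sign-based selection,
+#   A_new = torch.where(A >= 0, d_pos * A, d_neg * A)
+# For example:
+#   A = torch.randn(2, 3, 4); d_pos = torch.rand(2, 3, 4); d_neg = -torch.rand(2, 3, 4)
+#   A_new, _ = _reference_multiply_by_A_signs(A, d_pos, d_neg, None, None, False)
+#   assert torch.allclose(A_new, torch.where(A >= 0, d_pos * A, d_neg * A))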
+
+class ClampedMultiplication(torch.autograd.Function):
+ @staticmethod
+ @torch.jit.script
+ def clamp_multiply_forward(A: Tensor, d_pos: Tensor, d_neg: Tensor,
+ b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]:
+ """Forward operations; actually the same as the reference implementation."""
+ A_pos = A.clamp(min=0)
+ A_neg = A.clamp(max=0)
+ A_new = d_pos * A_pos + d_neg * A_neg
+ bias_pos = bias_neg = torch.tensor(0.)
+ if b_pos is not None:
+ if patches_mode:
+ bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos)
+ else:
+ bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos)
+ if b_neg is not None:
+ if patches_mode:
+ bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg)
+ else:
+ bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg)
+ return A_new, bias_pos + bias_neg
+
+ @staticmethod
+ @torch.jit.script
+ def clamp_multiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor,
+ b_pos: Optional[Tensor], b_neg: Optional[Tensor], grad_output_A: Tensor, grad_output_bias: Optional[Tensor],
+ patches_mode: bool) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], None]:
+ """Improved backward operation. This should be better than the backward function generated by Pytorch."""
+ if grad_output_bias is not None:
+ extension_dim = len(A.shape) - len(grad_output_bias.shape)
+ grad_output_bias = grad_output_bias.view(grad_output_bias.shape + (1, ) * extension_dim)
+ A_pos_mask = (A >= 0).to(dtype=grad_output_A.dtype)
+ A_neg_mask = 1. - A_pos_mask
+ A_pos_grad_output_A = A_pos_mask * grad_output_A
+ A_neg_grad_output_A = A_neg_mask * grad_output_A
+ gd_pos = A * A_pos_grad_output_A
+ gd_neg = A * A_neg_grad_output_A
+ if b_pos is not None and b_neg is not None and grad_output_bias is not None:
+ A_pos_grad_output_bias = A_pos_mask * grad_output_bias
+ A_neg_grad_output_bias = A_neg_mask * grad_output_bias
+ gb_neg = A * A_neg_grad_output_bias
+ gb_pos = A * A_pos_grad_output_bias
+ # gA has 4 terms.
+ gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias + b_neg * A_neg_grad_output_bias
+ elif b_neg is not None and grad_output_bias is not None:
+ A_neg_grad_output_bias = A_neg_mask * grad_output_bias
+ gb_neg = A * A_neg_grad_output_bias
+ gb_pos = None
+ # gA has 3 terms.
+ gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_neg * A_neg_grad_output_bias
+ elif b_pos is not None and grad_output_bias is not None:
+ A_pos_grad_output_bias = A_pos_mask * grad_output_bias
+ gb_pos = A * A_pos_grad_output_bias
+ gb_neg = None
+ # gA has 3 terms.
+ gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias
+ else:
+ # gA has 2 terms.
+ gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A
+ gb_pos = gb_neg = None
+ return gA, gd_pos, gd_neg, gb_pos, gb_neg, None
+
+ @staticmethod
+ def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode):
+ # No need to save the intermediate A_pos, A_neg as they have been fused into the computation.
+ ctx.save_for_backward(A, d_pos, d_neg, b_pos, b_neg)
+ ctx.patches_mode = patches_mode
+ return ClampedMultiplication.clamp_multiply_forward(A, d_pos, d_neg, b_pos, b_neg, patches_mode)
+
+ @staticmethod
+ def backward(ctx, grad_output_A, grad_output_bias):
+ A, d_pos, d_neg, b_pos, b_neg = ctx.saved_tensors
+ patches_mode = ctx.patches_mode
+ return ClampedMultiplication.clamp_multiply_backward(A, d_pos, d_neg, b_pos, b_neg, grad_output_A, grad_output_bias, patches_mode)
+
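+# Gradient-check sketch (hypothetical shapes, double precision): the handwritten
+# backward above can be validated against autograd, e.g.
+#   args = tuple(torch.randn(2, 3, 5, dtype=torch.float64, requires_grad=True)
+#                for _ in range(5))
+#   torch.autograd.gradcheck(
+#       lambda A, dp, dn, bp, bn: ClampedMultiplication.apply(A, dp, dn, bp, bn, False),
+#       args)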
+
+def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'):
+ if isinstance(A, Tensor):
+ if contiguous is True or contiguous == 'auto':
+ # For dense mode, convert d_pos and d_neg to contiguous tensor by default.
+ d_pos = d_pos.contiguous()
+ d_neg = d_neg.contiguous()
+ if d_pos.ndim == 1:
+ # Special case for LSTM, where the bias term is 1-dimensional. (FIXME)
+ assert d_neg.ndim == 1 and b_pos.ndim == 1 and b_neg.ndim == 1
+ new_A = A.clamp(min=0) * d_pos + A.clamp(max=0) * d_neg
+ new_bias = A.clamp(min=0) * b_pos + A.clamp(max=0) * b_neg
+ return new_A, new_bias
+ return ClampedMultiplication.apply(A, d_pos, d_neg, b_pos, b_neg, False)
+ elif isinstance(A, Patches):
+ if contiguous is True:
+ # For patches mode, do not convert d_pos and d_neg to contiguous tensors by default.
+ d_pos = d_pos.contiguous()
+ d_neg = d_neg.contiguous()
+ assert A.identity == 0 # TODO: handle the A.identity = 1 case. Currently not used.
+ patches = A.patches
+ patches_shape = patches.shape
+ # patches shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.
+ # or (unstable_size, batch_size, in_c, H, W) when it is sparse.
+ if len(patches_shape) == 6:
+ patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_pos is not None else None
+ d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_neg is not None else None
+ b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_pos is not None else None
+ b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_neg is not None else None
+ # Apply the multiplication based on signs.
+ A_prod, bias = ClampedMultiplication.apply(patches, d_pos, d_neg, b_pos, b_neg, True)
+ # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.
+ # For sparse patches the return bias size is (unstable_size, batch).
+ # For regular patches the return bias size is (spec, batch, out_h, out_w).
+ if len(patches_shape) == 6:
+ A_prod = A_prod.view(*patches_shape)
+ return A.create_similar(A_prod), bias
+
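+# Typical call pattern (illustrative names) in CROWN-style backward propagation:
+# with linear relaxation slopes upper_d / lower_d and intercepts upper_b / lower_b
+# of an activation, the propagated A matrices and bias terms are obtained as
+#   uA, ubias = multiply_by_A_signs(last_uA, upper_d, lower_d, upper_b, lower_b)
+#   lA, lbias = multiply_by_A_signs(last_lA, lower_d, upper_d, lower_b, upper_b)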
+
+def _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=False, n_test=20, warmup=3):
+ """Benchmarking function."""
+ print(f'patches_mode = {patches_mode}, b_pos is {type(b_pos)}, b_neg is {type(b_neg)}')
+ total_ref = 0.
+ total_new = 0.
+ run = ['ref', 'new']
+ for i in range(n_test):
+ ref_time = new_time = 0.
+
+ if 'ref' in run:
+ torch.cuda.synchronize()
+ start = time.time()
+ ref_A, ref_bias = _reference_multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode)
+ ref_loss = ref_A.sum() + ref_bias.sum()
+ ref_loss.backward()
+ torch.cuda.synchronize()
+ ref_time = time.time() - start
+ ref_gA = A.grad.detach().clone()
+ ref_gd_pos = d_pos.grad.detach().clone()
+ ref_gd_neg = d_neg.grad.detach().clone()
+ ref_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.)
+ ref_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.)
+ A.grad = d_pos.grad = d_neg.grad = None
+ if b_pos is not None:
+ b_pos.grad = None
+ if b_neg is not None:
+ b_neg.grad = None
+ del ref_loss
+
+ if 'new' in run:
+ torch.cuda.synchronize()
+ start = time.time()
+ new_A, new_bias = multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode)
+ new_loss = new_A.sum() + new_bias.sum()
+ new_loss.backward()
+ torch.cuda.synchronize()
+ new_time = time.time() - start
+ new_gA = A.grad.detach().clone()
+ new_gd_pos = d_pos.grad.detach().clone()
+ new_gd_neg = d_neg.grad.detach().clone()
+ new_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.)
+ new_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.)
+ A.grad = d_pos.grad = d_neg.grad = None
+ if b_pos is not None:
+ b_pos.grad = None
+ if b_neg is not None:
+ b_neg.grad = None
+ del new_loss
+
+ print(f'Loop {i:3d} {"(warmup)" if i < warmup else " "} time ref {ref_time:.5f} new {new_time:.6f} speedup {ref_time / new_time if i >= warmup else float("nan"):.3f}')
+ if i >= warmup:
+ total_ref += ref_time
+ total_new += new_time
+
+ if 'ref' in run and 'new' in run:
+ A_diff = (ref_A - new_A).abs().sum().item() / ref_A.abs().sum().item()
+ gA_diff = (ref_gA - new_gA).abs().sum().item() / ref_gA.abs().sum().item()
+ bias_diff = (ref_bias - new_bias).abs().sum().item() / (ref_bias.abs().sum().item() + 1e-10)
+ gd_pos_diff = (ref_gd_pos - new_gd_pos).abs().sum().item() / ref_gd_pos.abs().sum().item()
+ gd_neg_diff = (ref_gd_neg - new_gd_neg).abs().sum().item() / ref_gd_neg.abs().sum().item()
+ gb_pos_diff = (ref_gb_pos - new_gb_pos).abs().sum().item() / (ref_gb_pos.abs().sum().item() + 1e-10)
+ gb_neg_diff = (ref_gb_neg - new_gb_neg).abs().sum().item() / (ref_gb_neg.abs().sum().item() + 1e-10)
+ print(f' diff {A_diff} {gA_diff} {bias_diff} {gd_pos_diff} {gd_neg_diff} {gb_pos_diff} {gb_neg_diff}')
+ assert A_diff < 1e-6 and bias_diff < 1e-6 and gA_diff < 1e-6 and gd_pos_diff < 1e-6 and gd_neg_diff < 1e-6
+ assert gb_pos_diff < 1e-6 and gb_neg_diff < 1e-6
+
+
+ avg_ref_time = total_ref / (n_test - warmup)
+ avg_new_time = total_new / (n_test - warmup)
+ print(f'Avg. time: reference {avg_ref_time:.5f} new {avg_new_time:.6f} speedup {avg_ref_time / avg_new_time:.3f}')
+
+
+if __name__ == '__main__':
+ for patches_mode in [True, False]:
+ if patches_mode:
+ shape = (256, 8, 8, 8, 16, 32)
+ else:
+ shape = (256, 8, 128, 256)
+ A = torch.randn(shape, device='cuda', requires_grad=True)
+ d_pos = torch.randn(shape, device='cuda', requires_grad=True)
+ d_neg = torch.randn(shape, device='cuda', requires_grad=True)
+ b_pos = torch.randn(shape, device='cuda', requires_grad=True)
+ b_neg = torch.randn(shape, device='cuda', requires_grad=True)
+ _speed_test(A, d_pos, d_neg, None, None, patches_mode=patches_mode)
+ _speed_test(A, d_pos, d_neg, None, b_neg, patches_mode=patches_mode)
+ _speed_test(A, d_pos, d_neg, b_pos, None, patches_mode=patches_mode)
+ _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=patches_mode)
+ print('Press the Enter key to continue.')
+ input()
+ del A, d_pos, d_neg, b_pos, b_neg
diff --git a/auto_LiRPA/operators/constant.py b/auto_LiRPA/operators/constant.py
index 6d01883..9b6b52f 100644
--- a/auto_LiRPA/operators/constant.py
+++ b/auto_LiRPA/operators/constant.py
@@ -2,12 +2,11 @@
from .base import *
class BoundConstant(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.value = attr['value'].to(self.device)
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self):
return self.value.to(self.device)
@@ -19,7 +18,7 @@ def _bound_oneside(A):
if A is None:
return 0.0
- if type(A) == torch.Tensor:
+ if type(A) == Tensor:
if A.ndim > 2:
A = torch.sum(A, dim=list(range(2, A.ndim)))
elif type(A) == Patches:
@@ -39,22 +38,22 @@ def bound_forward(self, dim_in):
lb = ub = self.value
return LinearBound(lw, lb, uw, ub)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.value
+
+
class BoundPrimConstant(Bound):
- def __init__(self, input_name, name, ori_name, attr, input, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, input, output_index, options, device)
- self.value = attr['value']
+ def __init__(self, attr, input, output_index, options):
+ super().__init__(attr, input, output_index, options)
- @Bound.save_io_shape
def forward(self):
return torch.tensor([], device=self.device)
class BoundConstantOfShape(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.device = device
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.value = attr['value'].to(self.device)
- @Bound.save_io_shape
def forward(self, x):
self.x = x
self.from_input = True
@@ -85,9 +84,13 @@ def bound_forward(self, dim_in, x):
def interval_propagate(self, *v):
self.x = v[0][0]
- value = torch.ones(list(v[0][0]), device=self.device) * self.value
+ size = int(v[0][0].item()) if isinstance(v[0][0], Tensor) else v[0][0]
+ value = torch.ones(size, device=self.device) * self.value
return value, value
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(*v)
+
def infer_batch_dim(self, batch_size, *x):
# FIXME Should avoid referring to batch_size; Treat `torch.Size` results differently
if self.x[0] == batch_size:
@@ -96,10 +99,10 @@ def infer_batch_dim(self, batch_size, *x):
return -1
class BoundRange(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.device = attr['device']
- @Bound.save_io_shape
def forward(self, start, end, step):
if start.dtype == end.dtype == step.dtype == torch.int64:
return torch.arange(start, end, step, dtype=torch.int64, device=self.device)
@@ -111,10 +114,10 @@ def infer_batch_dim(self, batch_size, *x):
return -1
class BoundATenDiag(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.device = attr['device']
- @Bound.save_io_shape
def forward(self, x, diagonal=0):
return torch.diag(x, diagonal=diagonal)
@@ -125,10 +128,10 @@ def infer_batch_dim(self, batch_size, *x):
return 1 # This is not a batch operation.
class BoundATenDiagonal(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.device = attr['device']
- @Bound.save_io_shape
def forward(self, x, offset=0, dim1=0, dim2=1):
return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2)
diff --git a/auto_LiRPA/operators/convolution.py b/auto_LiRPA/operators/convolution.py
index 6bc5d3f..22343f1 100644
--- a/auto_LiRPA/operators/convolution.py
+++ b/auto_LiRPA/operators/convolution.py
@@ -1,14 +1,16 @@
-""" Convolution, pooling and padding operators"""
+""" Convolution and padding operators"""
from .base import *
-from .activation import BoundOptimizableActivation
+import numpy as np
+from .solver_utils import grb
+from ..patches import unify_shape, compute_patches_stride_padding, is_shape_used
class BoundConv(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
+ def __init__(self, attr, inputs, output_index, options):
assert (attr['pads'][0] == attr['pads'][2])
assert (attr['pads'][1] == attr['pads'][3])
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ super().__init__(attr, inputs, output_index, options)
self.stride = attr['strides']
self.padding = [attr['pads'][0], attr['pads'][1]]
@@ -18,14 +20,13 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
self.has_bias = True
else:
self.has_bias = False
- self.to(device)
+ self.relu_followed = False
+ self.patches_start = True
self.mode = options.get("conv_mode", "matrix")
- self.relu_followed = False
# denote whether this Conv is followed by a ReLU
- # if self.relu_followed is False, we need to manually pad the conv patches.
+ # if self.relu_followed is False, we need to manually pad the conv patches.
# If self.relu_followed is True, the patches are padded in the ReLU layer and the manual padding is not needed.
- @Bound.save_io_shape
def forward(self, *x):
# x[0]: input, x[1]: weight, x[2]: bias if self.has_bias
bias = x[2] if self.has_bias else None
@@ -46,14 +47,14 @@ def _bound_oneside(last_A):
# Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead.
shape = last_A.shape # [spec, batch, C, H, W]
dim = int(prod(shape[2:]))
- dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device)
+ dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype)
# last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A.
# last_A.coeffs is the value of the non-zero element.
dense_last_A = torch.scatter(dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), src=last_A.coeffs.unsqueeze(-1))
# We created a large A matrix and it will be handled below.
last_A = dense_last_A.view(shape[0], shape[1], *shape[2:])
- if type(last_A) == torch.Tensor:
+ if type(last_A) == Tensor:
shape = last_A.size()
# when (W−F+2P)%S != 0, construct the output_padding
output_padding0 = int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 * \
@@ -65,7 +66,8 @@ def _bound_oneside(last_A):
groups=self.groups, output_padding=(output_padding0, output_padding1))
next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])
if self.has_bias:
- sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)
+ # sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)
+ sum_bias = torch.einsum('sbchw,c->sb', last_A, x[2].lower)
else:
sum_bias = 0
return next_A, sum_bias
@@ -76,13 +78,15 @@ def _bound_oneside(last_A):
if not self.relu_followed: # FIXME (09/20): Don't call it relu_followed. Instead, make this a property of A, called "padded" and propagate this property.
# The last_A.patches was not padded, so we need to pad them here.
# If this Conv layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again.
- one_d = torch.ones(tuple(1 for i in self.output_shape), device=last_A.patches.device).expand(self.output_shape)
- # After unfolding, the shape is (batch, out_h, out_w, in_c, h, w)
- one_d_unfolded = inplace_unfold(one_d, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding)
+ one_d = torch.ones(tuple(1 for i in self.output_shape[1:]), device=last_A.patches.device, dtype=weight.dtype).expand(self.output_shape[1:])
+ # Add batch dimension.
+ one_d = one_d.unsqueeze(0)
+ # After unfolding, the shape is (1, out_h, out_w, in_c, h, w)
+ one_d_unfolded = inplace_unfold(one_d, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding)
if last_A.unstable_idx is not None:
# Move out_h, out_w dimension to the front for easier selection.
one_d_unfolded_r = one_d_unfolded.permute(1, 2, 0, 3, 4, 5)
- # for sparse patches the shape is (unstable_size, batch, in_c, h, w).
+ # for sparse patches the shape is (unstable_size, batch, in_c, h, w). Batch size is 1 so no need to select here.
one_d_unfolded_r = one_d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
else:
# Append the spec dimension.
@@ -100,7 +104,7 @@ def _bound_oneside(last_A):
sum_bias = 0
flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1))
- pieces = F.conv_transpose2d(flattened_patches, weight, stride=self.stride)
+ pieces = F.conv_transpose2d(flattened_patches, insert_zeros(weight, last_A.inserted_zeros), stride=self.stride)
# New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w).
pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1))
@@ -114,37 +118,43 @@ def _bound_oneside(last_A):
# Expand the batch dimension.
pieces = pieces.expand(-1, last_A.shape[1], -1, -1, -1)
# Do the same for the bias.
- sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1)
- # bias has shape (unstable_size, batch).
- sum_bias = sum_bias.expand(-1, last_A.shape[1])
+ if self.has_bias:
+ sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1)
+ # bias has shape (unstable_size, batch).
+ sum_bias = sum_bias.expand(-1, last_A.shape[1])
+ else:
+ sum_bias = 0
else:
assert weight.size(0) == last_A.shape[0]
pieces = weight.view(weight.size(0), 1, 1, 1, weight.size(1), weight.size(2), weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1)
# The bias (x[2].lower) has shape (out_c,) need to make it (out_c, batch, out_h, out_w).
# Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version
- sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])
+ if self.has_bias:
+ sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])
+ else:
+ sum_bias = 0
else:
raise NotImplementedError()
padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom)
- stride = last_A.stride if last_A is not None else 1
+ stride = last_A.stride if last_A is not None else (1, 1)
+ inserted_zeros = last_A.inserted_zeros if last_A is not None else 0
+ output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0)
- if type(padding) == int:
- padding = padding * self.stride[0] + self.padding[0]
- else:
- padding = tuple(p * self.stride[0] + self.padding[0] for p in padding)
- stride *= self.stride[0]
+ padding, stride, output_padding = compute_patches_stride_padding(self.input_shape, padding, stride, self.padding, self.stride, inserted_zeros, output_padding)
- if pieces.shape[-1] > self.input_shape[-1]: # the patches is too large and from now on, we will use matrix mode instead of patches mode.
+ if inserted_zeros == 0 and not is_shape_used(output_padding) and pieces.shape[-1] > self.input_shape[-1]: # the patches are too large, so from now on we will use matrix mode instead of patches mode.
# This is our desired matrix: the input will be flattened to (batch_size, input_channel*input_x*input_y) and multiplied by this matrix.
# After multiplication, the desired output is (batch_size, out_channel*output_x*output_y).
# A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w)
A_matrix = patches_to_matrix(pieces, self.input_shape[1:], stride, padding, last_A.output_shape, last_A.unstable_idx)
- if isinstance(sum_bias, torch.Tensor) and last_A.unstable_idx is None:
+ # print(f'Converting patches to matrix: old shape {pieces.shape}, size {pieces.numel()}; new shape {A_matrix.shape}, size {A_matrix.numel()}')
+ if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None:
sum_bias = sum_bias.transpose(0, 1)
sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1)
A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front.
return A_matrix, sum_bias
- return Patches(pieces, stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), sum_bias
+ # print(f'Conv returns patches with size={pieces.size()}, stride={stride}, padding={padding}, inserted_zeros={inserted_zeros}, output_padding={output_padding}')
+ return Patches(pieces, stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape, inserted_zeros=last_A.inserted_zeros, output_padding=output_padding), sum_bias
else:
raise NotImplementedError()
@@ -152,37 +162,102 @@ def _bound_oneside(last_A):
uA_x, ubias = _bound_oneside(last_uA)
return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias
- def bound_forward(self, dim_in, *x):
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
if self.is_input_perturbed(1):
raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.")
- weight = x[1].lb
- bias = x[2].lb if self.has_bias else None
- x = x[0]
- input_dim = x.lb.shape[-2] * x.lb.shape[-1]
- wshape = x.lw.shape
- eye = torch.eye(input_dim).view(input_dim, 1, *x.lb.shape[-2:])
- weight = F.conv2d(eye, weight, None, self.stride, self.padding, self.dilation, self.groups)
- weight = weight.view(input_dim, -1)
- output_dim = weight.shape[-1]
- bias = bias.view(1, -1, 1).repeat(1, 1, output_dim // bias.shape[0]).view(*self.output_shape[1:])
- batch_size = x.lb.shape[0]
-
- lw = (x.lw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(min=0)) +
- x.uw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(max=0)))\
- .reshape(batch_size, dim_in, *self.output_shape[1:])
- uw = (x.uw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(min=0)) +
- x.lw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(max=0)))\
- .reshape(batch_size, dim_in, *self.output_shape[1:])
-
- lb = (x.lb.reshape(batch_size, -1).matmul(weight.clamp(min=0)) +
- x.ub.reshape(batch_size, -1).matmul(weight.clamp(max=0)))\
- .reshape(batch_size, *self.output_shape[1:]) + bias
- ub = (x.ub.reshape(batch_size, -1).matmul(weight.clamp(min=0)) +
- x.lb.reshape(batch_size, -1).matmul(weight.clamp(max=0)))\
- .reshape(batch_size, *self.output_shape[1:]) + bias
-
- return LinearBound(lw, lb, uw, ub)
+ assert self.dilation == (1, 1) or self.dilation == [1, 1]
+ # e.g., last layer input gurobi vars (3,32,32)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1,3,32,32)
+ pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape
+ # this layer shape (1,8,16,16)
+ this_layer_shape = self.output_shape
+ out_lbs, out_ubs = None, None
+ if hasattr(self, "lower"):
+ # self.lower shape (1,8,16,16)
+ out_lbs = self.lower.cpu().numpy()
+ out_ubs = self.upper.cpu().numpy()
+
+ # current layer weight (8,3,4,4)
+ this_layer_weight = v[1].detach().cpu().numpy()
+ # current layer bias (8,)
+ this_layer_bias = None
+ if self.has_bias:
+ this_layer_bias = v[2].detach().cpu().numpy()
+ weight_shape2, weight_shape3 = this_layer_weight.shape[2], this_layer_weight.shape[3]
+ padding0, padding1 = self.padding[0], self.padding[1]
+ stride0, stride1 = self.stride[0], self.stride[1]
+
+ new_layer_gurobi_vars = []
+ new_layer_gurobi_constrs = []
+
+ neuron_idx = 0
+ for out_chan_idx in range(this_layer_shape[1]):
+ out_chan_vars = []
+ for out_row_idx in range(this_layer_shape[2]):
+ out_row_vars = []
+ for out_col_idx in range(this_layer_shape[3]):
+ # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1))
+ lin_expr = 0
+ if self.has_bias:
+ lin_expr = this_layer_bias[out_chan_idx]
+
+ for in_chan_idx in range(this_layer_weight.shape[1]):
+
+ # New version of the conv-layer MIP encoding that skips the explicit kernel loops.
+ ker_row_min, ker_row_max = 0, weight_shape2
+ in_row_idx_min = -padding0 + stride0 * out_row_idx
+ in_row_idx_max = in_row_idx_min + weight_shape2 - 1
+ if in_row_idx_min < 0:
+ ker_row_min = -in_row_idx_min
+ if in_row_idx_max >= pre_layer_shape[2]:
+ ker_row_max = ker_row_max - in_row_idx_max + pre_layer_shape[2] - 1
+ in_row_idx_min, in_row_idx_max = max(in_row_idx_min, 0), min(in_row_idx_max, pre_layer_shape[2] - 1)
+
+ ker_col_min, ker_col_max = 0, weight_shape3
+ in_col_idx_min = -padding1 + stride1 * out_col_idx
+ in_col_idx_max = in_col_idx_min + weight_shape3 - 1
+ if in_col_idx_min < 0:
+ ker_col_min = -in_col_idx_min
+ if in_col_idx_max >= pre_layer_shape[3]:
+ ker_col_max = ker_col_max - in_col_idx_max + pre_layer_shape[3] - 1
+ in_col_idx_min, in_col_idx_max = max(in_col_idx_min, 0), min(in_col_idx_max, pre_layer_shape[3] - 1)
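+ # Worked example (hypothetical numbers): with padding0 = 1, stride0 = 2,
+ # kernel height 4, input height 32 and out_row_idx = 0, we get
+ # in_row_idx_min = -1, hence ker_row_min = 1 and the input rows are clipped
+ # to 0..2; kernel rows 1..3 then align with input rows 0..2, exactly the
+ # terms the explicit kernel loops would have produced.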
+
+ coeffs = this_layer_weight[out_chan_idx, in_chan_idx, ker_row_min:ker_row_max, ker_col_min:ker_col_max].reshape(-1)
+
+ gvars = gvars_array[in_chan_idx, in_row_idx_min:in_row_idx_max+1, in_col_idx_min:in_col_idx_max+1].reshape(-1)
+ if solver_pkg == 'gurobi':
+ lin_expr += grb.LinExpr(coeffs, gvars)
+ else:
+ # lin_expr += coeffs@gvars
+
+ for i in range(len(coeffs)):
+ try:
+ lin_expr += coeffs[i] * gvars[i]
+ except TypeError:
+ lin_expr += coeffs[i] * gvars[i].var
+
+
+ out_lb = out_lbs[0, out_chan_idx, out_row_idx, out_col_idx] if out_lbs is not None else -float('inf')
+ out_ub = out_ubs[0, out_chan_idx, out_row_idx, out_col_idx] if out_ubs is not None else float('inf')
+ var = model.addVar(lb=out_lb, ub=out_ub,
+ obj=0, vtype=grb.GRB.CONTINUOUS,
+ # name=f'lay{layer_idx}_[{out_chan_idx}, {out_row_idx}, {out_col_idx}]')
+ name=f'lay{self.name}_{neuron_idx}')
+ # model.addConstr(lin_expr == var, name=f'lay{layer_idx}_[{out_chan_idx}, {out_row_idx}, {out_col_idx}]_eq')
+ # new_layer_gurobi_constrs.append(
+ # model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq'))
+ model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')
+ neuron_idx += 1
+
+ out_row_vars.append(var)
+ out_chan_vars.append(out_row_vars)
+ new_layer_gurobi_vars.append(out_chan_vars)
+
+ self.solver_vars = new_layer_gurobi_vars
+ # self.solver_constrs = new_layer_gurobi_constrs
+ model.update()
def interval_propagate(self, *v, C=None):
if self.is_input_perturbed(1):
@@ -191,32 +266,6 @@ def interval_propagate(self, *v, C=None):
norm = Interval.get_perturbation(v[0])
norm = norm[0]
- if Interval.use_relative_bounds(*v):
- bias = v[2].nominal if self.has_bias else None
- if norm == np.inf:
- weight = v[1].nominal
- nominal = F.conv2d(
- v[0].nominal, weight, bias,
- self.stride, self.padding, self.dilation, self.groups)
- lower_offset = (F.conv2d(
- v[0].lower_offset, weight.clamp(min=0), None,
- self.stride, self.padding, self.dilation, self.groups) +
- F.conv2d(
- v[0].upper_offset, weight.clamp(max=0), None,
- self.stride, self.padding, self.dilation, self.groups))
- upper_offset = (F.conv2d(
- v[0].upper_offset, weight.clamp(min=0), None,
- self.stride, self.padding, self.dilation, self.groups) +
- F.conv2d(
- v[0].lower_offset, weight.clamp(max=0), None,
- self.stride, self.padding, self.dilation, self.groups))
- return Interval(
- None, None, nominal=nominal,
- lower_offset=lower_offset, upper_offset=upper_offset
- )
- else:
- raise NotImplementedError
-
h_L, h_U = v[0]
weight = v[1][0]
bias = v[2][0] if self.has_bias else None
@@ -247,15 +296,47 @@ def interval_propagate(self, *v, C=None):
ss = center.shape
deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3])
-
+
center = F.conv2d(mid, weight, bias, self.stride, self.padding, self.dilation, self.groups)
upper = center + deviation
lower = center - deviation
return lower, upper
+ def bound_dynamic_forward(self, *x, max_dim=None, offset=0):
+ if self.is_input_perturbed(1) or self.is_input_perturbed(2):
+ raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.")
+ weight = x[1].lb
+ bias = x[2].lb if self.has_bias else None
+ x = x[0]
+ w = x.lw
+ b = x.lb
+ shape = w.shape
+ shape_wconv = [shape[0] * shape[1]] + list(shape[2:])
+ def conv2d(input, weight, bias, stride, padding, dilation, groups):
+ """ There may be some CUDA error (illegal memory access) when
+ the batch size is too large. Thus split the input into several
+ batches when needed. """
+ max_batch_size = 50
+ if input.device != torch.device('cpu') and input.shape[0] > max_batch_size:
+ ret = []
+ for i in range((input.shape[0] + max_batch_size - 1) // max_batch_size):
+ ret.append(F.conv2d(
+ input[i*max_batch_size:(i+1)*max_batch_size],
+ weight, bias, stride, padding, dilation, groups))
+ return torch.cat(ret, dim=0)
+ else:
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+ w_new = conv2d(
+ w.reshape(shape_wconv), weight, None, self.stride, self.padding,
+ self.dilation, self.groups)
+ w_new = w_new.reshape(shape[0], -1, *w_new.shape[1:])
+ b_new = conv2d(
+ b, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+ return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)
+
def bound_forward(self, dim_in, *x):
- if self.is_input_perturbed(1):
+ if self.is_input_perturbed(1) or self.is_input_perturbed(2):
raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.")
weight = x[1].lb
@@ -270,16 +351,16 @@ def bound_forward(self, dim_in, *x):
shape = mid_w.shape
shape_wconv = [shape[0] * shape[1]] + list(shape[2:])
deviation_w = F.conv2d(
- diff_w.reshape(shape_wconv), weight_abs, None,
+ diff_w.reshape(shape_wconv), weight_abs, None,
self.stride, self.padding, self.dilation, self.groups)
deviation_b = F.conv2d(
- diff_b, weight_abs, None,
+ diff_b, weight_abs, None,
self.stride, self.padding, self.dilation, self.groups)
center_w = F.conv2d(
- mid_w.reshape(shape_wconv), weight, None,
+ mid_w.reshape(shape_wconv), weight, None,
self.stride, self.padding, self.dilation, self.groups)
- center_b = F.conv2d(
- mid_b, weight, bias,
+ center_b = F.conv2d(
+ mid_b, weight, bias,
self.stride, self.padding, self.dilation, self.groups)
deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:])
center_w = center_w.reshape(shape[0], -1, *center_w.shape[1:])
@@ -290,164 +371,219 @@ def bound_forward(self, dim_in, *x):
uw = center_w + deviation_w,
ub = center_b + deviation_b)
+class BoundConvTranspose(Bound):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ assert (attr['pads'][0] == attr['pads'][2])
+ assert (attr['pads'][1] == attr['pads'][3])
-class BoundMaxPool(BoundOptimizableActivation):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])
- assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])
-
- self.nonlinear = True
- self.kernel_size = attr['kernel_shape']
self.stride = attr['strides']
self.padding = [attr['pads'][0], attr['pads'][1]]
- self.ceil_mode = False
- self.use_default_ibp = True
- self.alpha = None
- self.init = {}
-
- @Bound.save_io_shape
- def forward(self, x):
- output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ self.dilation = attr['dilations']
+ self.groups = attr['group']
+ self.output_padding = [attr.get('output_padding', [0, 0])[0], attr.get('output_padding', [0, 0])[1]]
+ if len(inputs) == 3:
+ self.has_bias = True
+ else:
+ self.has_bias = False
+ self.mode = options.get("conv_mode", "matrix")
+ assert self.output_padding == [0, 0]
+ assert self.padding == [0, 0]
+ assert self.dilation == [1, 1]
+ assert self.stride[0] == self.stride[1]
+ assert self.groups == 1
+
+ def forward(self, *x):
+ # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias
+ bias = x[2] if self.has_bias else None
+ output = F.conv_transpose2d(x[0], x[1], bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)
return output
- def project_simplex(self, patches):
- sorted = torch.flatten(patches, -2)
- sorted, _ = torch.sort(sorted, -1, descending=True)
- rho_sum = torch.cumsum(sorted, -1)
- rho_value = 1 - rho_sum
- rho_value = (sorted + rho_value/torch.tensor(range(1, sorted.size(-1)+1), dtype=torch.float, device=sorted.device)) > 0
- _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1)
- rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1)
- lbd = 1/(rho_index+1)* (1-rho_sum)
-
- return torch.clamp(patches + lbd.unsqueeze(-1).unsqueeze(-1), min=0)
-
- def init_opt_parameters(self, start_nodes):
- batch_size, channel, h, w = self.input_shape
- o_h, o_w = self.output_shape[-2:]
- # batch_size, out_c, out_h, out_w, k, k
-
- self.alpha = OrderedDict()
- ref = self.inputs[0].lower # a reference variable for getting the shape
- for ns, size_s in start_nodes:
- self.alpha[ns] = torch.empty([1, size_s, self.input_shape[0], self.input_shape[1], self.output_shape[-2], self.output_shape[-1], self.kernel_size[0], self.kernel_size[1]],
- dtype=torch.float, device=ref.device, requires_grad=True)
- self.init[ns] = False
-
- def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, unstable_idx=None):
- paddings = tuple(self.padding + self.padding)
-
- A_shape = last_lA.shape if last_lA is not None else last_uA.shape
- # batch_size, input_c, x, y
- upper_d = torch.zeros((list(self.input_shape)), device=x.device)
- lower_d = torch.zeros((list(self.input_shape)), device=x.device)
-
- upper_d = F.pad(upper_d, paddings)
- lower_d = F.pad(lower_d, paddings)
-
- # batch_size, output_c, x, y
- upper_b = torch.zeros((list(self.output_shape)), device=x.device)
- lower_b = torch.zeros((list(self.output_shape)), device=x.device)
-
- # 1. find the index i where li > uj for all j, then set upper_d = lower_d = 1
- max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
- delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
- max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode)
-
- values = torch.zeros_like(max_lower)
- values[max_lower >= max_upper] = 1.0
- upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape)
-
- if self.opt_stage == 'opt':
- if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1:
- if unstable_idx.ndim == 1:
- # Only unstable neurons of the start_node neurons are used.
- alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
- elif unstable_idx.ndim == 2:
- # Each element in the batch selects different neurons.
- alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
- else:
- raise ValueError
- else:
- alpha = self.alpha[start_node.name]
-
- if self.init[start_node.name] == False:
- lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
- lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride)
-
- alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1])
- alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach())
- self.init[start_node.name] = True
- if self.padding[0] > 0:
- lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
-
- alpha.data = self.project_simplex(alpha.data).clone().detach()
- alpha = alpha.permute((0,1,2,3,6,7,4,5))
- alpha_shape = alpha.shape
- alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1]))
- lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride)
- lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:])
- lower_d = lower_d.squeeze(0)
- else:
- lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
- if self.padding[0] > 0:
- lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
- values[:] = 0.0
- max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
- values[max_upper > max_lower] = max_upper_[max_upper > max_lower]
- upper_b = values
+ def bound_backward(self, last_lA, last_uA, *x):
+ if self.is_input_perturbed(1):
+ raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.")
+
+ lA_y = uA_y = lA_bias = uA_bias = None
+ weight = x[1].lower
+ assert weight.size(-1) == weight.size(-2)
- assert type(last_lA) == torch.Tensor or type(last_uA) == torch.Tensor
- def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):
+ def _bound_oneside(last_A):
if last_A is None:
return None, 0
- pos_A = last_A.clamp(min=0)
- neg_A = last_A.clamp(max=0)
+ if type(last_A) is OneHotC:
+ # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead.
+ shape = last_A.shape # [spec, batch, C, H, W]
+ dim = int(prod(shape[2:]))
+ dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype)
+ # last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A.
+ # last_A.coeffs is the value of the non-zero element.
+ dense_last_A = torch.scatter(dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), src=last_A.coeffs.unsqueeze(-1))
+ # We created a large A matrix and it will be handled below.
+ last_A = dense_last_A.view(shape[0], shape[1], *shape[2:])
- bias = 0
- if b_pos is not None:
- bias = bias + self.get_bias(pos_A, b_pos)
- if b_neg is not None:
- bias = bias + self.get_bias(neg_A, b_neg)
+ if type(last_A) == Tensor:
+ shape = last_A.size()
+ next_A = F.conv2d(last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None,
+ stride=self.stride, padding=self.padding, dilation=self.dilation,
+ groups=self.groups)
+ next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])
+ if self.has_bias:
+ sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)
+ else:
+ sum_bias = 0
+ return next_A, sum_bias
+ elif type(last_A) == Patches:
+ # Here we build and propagate a Patch object with (patches, stride, padding)
+ assert type(last_A) == Patches
+ if last_A.identity == 0:
+ patches = last_A.patches
- shape = last_A.size()
- pos_A = F.interpolate(pos_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size)
- pos_A = F.pad(pos_A, (0, self.input_shape[-2] - pos_A.shape[-2], 0, self.input_shape[-1] - pos_A.shape[-1]))
- pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:])
+ # FIXME: so far, we assume there is a ReLU layer at its input.
- neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size)
- neg_A = F.pad(neg_A, (0, self.input_shape[-2] - neg_A.shape[-2], 0, self.input_shape[-1] - neg_A.shape[-1]))
- neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:])
+ if self.has_bias:
+ # bias is x[2] (lower and upper are the same), and has shape (c,).
+ # Patches either has [out_c, batch, out_h, out_w, c, h, w] or [unstable_size, batch, c, h, w].
+ sum_bias = torch.einsum('sb...chw,c->sb...', patches, x[2].lower)
+ # sum_bias has shape (out_c, batch, out_h, out_w) or (unstable_size, batch).
+ else:
+ sum_bias = 0
- next_A = pos_A * d_pos + neg_A * d_neg
- return next_A, bias
+ flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1))
+ # Merge patches with this layer's weights. The weight must be flipped here, and if stride != 1 we must insert zeros in the input image.
+ # For conv_transpose2d, the weight matrix is in the (in, out, k, k) shape.
+ pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=last_A.inserted_zeros + 1)
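+ # Equivalence sketch (hypothetical example): for stride s and kernel size k,
+ # F.conv_transpose2d(x, w, stride=s) equals
+ # F.conv2d(z, w.transpose(0, 1).flip(-1, -2)), where z is x with s - 1 zeros
+ # inserted between pixels and k - 1 zero padding on each border. This is why
+ # the flipped, transposed weight is merged into the patches here.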
+ # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w).
+ pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1))
- if self.padding[0] > 0:
- upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
+ elif last_A.identity == 1:
+ # New patches have size [out_c, batch, out_h, out_w, c, h, w] if it is not sparse.
+ # New patches have size [unstable_size, batch, c, h, w] if it is sparse.
+ if last_A.unstable_idx is not None:
+ raise NotImplementedError()
+ pieces = weight.view(weight.size(0), 1, weight.size(1), weight.size(2), weight.size(3))
+ # Select based on the output channel (out_h and out_w are irrelevant here).
+ pieces = pieces[last_A.unstable_idx[0]]
+ # Expand the batch dimension.
+ pieces = pieces.expand(-1, last_A.shape[1], -1, -1, -1)
+ # Do the same for the bias.
+ sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1)
+ # bias has shape (unstable_size, batch).
+ sum_bias = sum_bias.expand(-1, last_A.shape[1])
+ else:
+ assert weight.size(0) == last_A.shape[0]
+ pieces = weight.view(weight.size(0), 1, 1, 1, weight.size(1), weight.size(2), weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1)
+ # The bias (x[2].lower) has shape (out_c,) need to make it (out_c, batch, out_h, out_w).
+ # Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version
+ sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])
+ else:
+ raise NotImplementedError()
+ padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom)
+ output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom)
+ inserted_zeros = last_A.inserted_zeros
+ assert self.padding == [0, 0]
+ assert self.stride[0] == self.stride[1]
+
+ # Unify the shape to 4-tuple.
+ output_padding = unify_shape(output_padding)
+ padding = unify_shape(padding)
+ this_stride = unify_shape(self.stride)
+ this_padding = unify_shape(self.padding)
+
+ # Compute new padding.
+ padding = tuple(p + (weight.size(3 - j//2) - 1) for j, p in enumerate(padding))
+
+ # Compute new output padding
+ output_padding = tuple(p * this_stride[j] + this_padding[j] for j, p in enumerate(output_padding))
+ # When we run insert_zeros, it's missing the rightmost column and the bottom row.
+ # padding = (padding[0], padding[1] + inserted_zeros, padding[2], padding[3] + inserted_zeros)
+
+ # If no transposed conv has been encountered so far, inserted_zeros is 0.
+ # When a transposed conv is encountered, the stride is multiplied onto it.
+ inserted_zeros = (inserted_zeros + 1) * this_stride[0] - 1
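+ # Example: starting from inserted_zeros = 0, one stride-2 transposed conv gives
+ # (0 + 1) * 2 - 1 = 1 inserted zero; a second stride-2 transposed conv gives
+ # (1 + 1) * 2 - 1 = 3.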
+
+ # FIXME: disabled patches_to_matrix because not all parameters are supported.
+ if inserted_zeros == 0 and not is_shape_used(output_padding) and pieces.shape[-1] > self.input_shape[-1]: # the patches are too large, so from now on we will use matrix mode instead of patches mode.
+ # This is our desired matrix: the input will be flattened to (batch_size, input_channel*input_x*input_y) and multiplied by this matrix.
+ # After multiplication, the desired output is (batch_size, out_channel*output_x*output_y).
+ # A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w)
+ assert inserted_zeros == 0
+ A_matrix = patches_to_matrix(pieces, self.input_shape[1:], last_A.stride, padding, last_A.output_shape, last_A.unstable_idx)
+ if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None:
+ sum_bias = sum_bias.transpose(0, 1)
+ sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1)
+ A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front.
+ return A_matrix, sum_bias
+ return Patches(pieces, last_A.stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx,
+ output_shape=last_A.output_shape, inserted_zeros=inserted_zeros, output_padding=output_padding), sum_bias
+ else:
+ raise NotImplementedError()
+
+ lA_x, lbias = _bound_oneside(last_lA)
+ uA_x, ubias = _bound_oneside(last_uA)
+ return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias
+
+ def interval_propagate(self, *v, C=None):
+ if self.is_input_perturbed(1):
+ raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.")
+
+ norm = Interval.get_perturbation(v[0])
+ norm = norm[0]
+
+ h_L, h_U = v[0]
+ weight = v[1][0]
+ bias = v[2][0] if self.has_bias else None
+
+ if norm == np.inf:
+ mid = (h_U + h_L) / 2.0
+ diff = (h_U - h_L) / 2.0
+ weight_abs = weight.abs()
+ deviation = F.conv_transpose2d(diff, weight_abs, None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)
+ elif norm > 0:
+ raise NotImplementedError()
+ norm, eps = Interval.get_perturbation(v[0])
+ # L2 norm, h_U and h_L are the same.
+ mid = h_U
+ # TODO: padding
+ deviation = torch.mul(weight, weight).sum((1, 2, 3)).sqrt() * eps
+ deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ else: # Here we calculate the L0 norm IBP bound using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020]
+ raise NotImplementedError()
+ norm, eps, ratio = Interval.get_perturbation(v[0])
+ mid = h_U
+ k = int(eps)
+ weight_sum = torch.sum(weight.abs(), 1)
+ deviation = torch.sum(torch.topk(weight_sum.view(weight_sum.shape[0], -1), k)[0], dim=1) * ratio
- uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, upper_b, lower_b)
- lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b)
+ if self.has_bias:
+ center = F.conv2d(mid, weight, v[2][0], self.stride, self.padding, self.dilation, self.groups)
+ else:
+ center = F.conv2d(mid, weight, None, self.stride, self.padding, self.dilation, self.groups)
- return [(lA, uA)], lbias, ubias
+ ss = center.shape
+ deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3])
+
+ center = F.conv_transpose2d(mid, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)
+
+ upper = center + deviation
+ lower = center - deviation
+ return lower, upper
class BoundPad(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- if len(attr) == 1:
- self.padding = [0, 0, 0, 0]
- self.value = 0.0
- else:
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ if 'pads' in attr:
self.padding = attr['pads'][2:4] + attr['pads'][6:8]
- self.value = attr['value']
+ else:
+ self.padding = [0, 0, 0, 0]
+ self.value = attr.get('value', 0.0)
assert self.padding == [0, 0, 0, 0]
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- @Bound.save_io_shape
def forward(self, x, pad, value=0.0):
# TODO: padding for 3-D or more dimensional inputs.
assert x.ndim == 4
# x[1] should be [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right]
+ assert pad[0] == pad[1] == pad[4] == pad[5] == 0
pad = [int(pad[3]), int(pad[7]), int(pad[2]), int(pad[6])]
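+ # e.g., pad = [0, 0, 1, 2, 0, 0, 3, 4] (top=1, left=2, bottom=3, right=4)
+ # becomes [2, 4, 1, 3] = (left, right, top, bottom), the order F.pad expects.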
final = F.pad(x, pad, value=value)
self.padding, self.value = pad, value
@@ -459,7 +595,6 @@ def interval_propagate(self, *v):
def bound_backward(self, last_lA, last_uA, *x):
# TODO: padding for 3-D or more dimensional inputs.
- pad = self.padding
left, right, top, bottom = self.padding
def _bound_oneside(last_A):
if last_A is None:
@@ -470,7 +605,7 @@ def _bound_oneside(last_A):
new_padding = (last_A.padding[0] + left, last_A.padding[1] + right, last_A.padding[2] + top, last_A.padding[3] + bottom)
else:
new_padding = (last_A.padding + left, last_A.padding + right, last_A.padding + top, last_A.padding + bottom)
- return Patches(last_A.patches, last_A.stride, new_padding, last_A.shape, last_A.identity, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape)
+ return last_A.create_similar(padding=new_padding)
else:
shape = last_A.size()
return last_A[:, :, :, top:(shape[3] - bottom), left:(shape[4] - right)]
@@ -478,61 +613,42 @@ def _bound_oneside(last_A):
last_uA = _bound_oneside(last_uA)
return [(last_lA, last_uA), (None, None), (None, None)], 0, 0
-class BoundGlobalAveragePool(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
-
- @Bound.save_io_shape
- def forward(self, x):
- output = nn.AdaptiveAvgPool2d((1, 1)).forward(x) # adaptiveAveragePool with output size (1, 1)
- return output
-
- def bound_backward(self, last_lA, last_uA, x):
- H, W = self.input_shape[-2], self.input_shape[-1]
-
- lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None
- uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None
-
- return [(lA, uA)], 0, 0
-
- def interval_propagate(self, *v):
- h_L, h_U = v[0]
- h_L = F.adaptive_avg_pool2d(h_L, (1, 1))
- h_U = F.adaptive_avg_pool2d(h_U, (1, 1))
- return h_L, h_U
-
-class BoundAveragePool(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- # assumptions: ceil_mode=False, count_include_pad=True
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
-
- assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])
- assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])
-
- self.kernel_size = attr['kernel_shape']
- self.stride = attr['strides']
- self.padding = [attr['pads'][0], attr['pads'][1]]
- self.ceil_mode = False
- self.count_include_pad = True
- self.use_default_ibp = True
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ # e.g., last layer input gurobi vars (3,32,32)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1,3,32,32)
+ pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape
+ # this layer shape (1,3,35,35)
+ this_layer_shape = self.output_shape
+ # v1 = tensor([0, 0, 1, 1, 0, 0, 2, 2])
+ # [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right]
+ # => [left, right, top, bottom]
+ padding = [int(v[1][3]), int(v[1][7]), int(v[1][2]), int(v[1][6])]
+ left, right, top, bottom = padding
+ assert pre_layer_shape[2] + top + bottom == this_layer_shape[2]
+ assert pre_layer_shape[3] + left + right == this_layer_shape[3]
+
+ new_layer_gurobi_vars = []
+ neuron_idx = 0
+ for out_chan_idx in range(this_layer_shape[1]):
+ out_chan_vars = []
+ for out_row_idx in range(this_layer_shape[2]):
+ out_row_vars = []
+ row_pad = out_row_idx < top or out_row_idx >= this_layer_shape[2] - bottom
+ for out_col_idx in range(this_layer_shape[3]):
+ col_pad = out_col_idx < left or out_col_idx >= this_layer_shape[3] - right
+ if row_pad or col_pad:
+ var = model.addVar(lb=0, ub=0,
+ obj=0, vtype=grb.GRB.CONTINUOUS,
+ name=f'pad{self.name}_{neuron_idx}')
+ else:
+ var = gvars_array[out_chan_idx, out_row_idx - top, out_col_idx - left]
+ # print(out_chan_idx, out_row_idx, out_col_idx, row_pad, col_pad, var.LB, var.UB)
+ neuron_idx += 1
- @Bound.save_io_shape
- def forward(self, x):
- return F.avg_pool2d(x, self.kernel_size, self.stride,
- self.padding, self.ceil_mode, self.count_include_pad)
+ out_row_vars.append(var)
+ out_chan_vars.append(out_row_vars)
+ new_layer_gurobi_vars.append(out_chan_vars)
- def bound_backward(self, last_lA, last_uA, x):
- def _bound_oneside(last_A):
- if last_A is None:
- return None, 0
- shape = last_A.size()
- # propagate A to the next layer, with batch concatenated together
- next_A = F.interpolate(last_A.view(shape[0] * shape[1], *shape[2:]),
- scale_factor=self.kernel_size) / (prod(self.kernel_size))
- next_A = F.pad(next_A, (0, self.input_shape[-2] - next_A.shape[-2], 0, self.input_shape[-1] - next_A.shape[-1]))
- next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])
- return next_A, 0
-
- lA, lbias = _bound_oneside(last_lA)
- uA, ubias = _bound_oneside(last_uA)
- return [(lA, uA)], lbias, ubias
\ No newline at end of file
+ self.solver_vars = new_layer_gurobi_vars
+ model.update()
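+
+# A minimal sketch (illustrative helper, not called anywhere in the library)
+# of the index arithmetic used in build_solver above: interior output pixels
+# map back to input pixels shifted by the top/left padding; all other output
+# pixels are padding and get fixed zero variables.
+def _pad_index_mapping_demo():
+    pads = [0, 0, 1, 1, 0, 0, 2, 2]  # ONNX order: [0, 0, top, left, 0, 0, bottom, right]
+    top, left, bottom, right = pads[2], pads[3], pads[6], pads[7]
+    in_h = in_w = 32
+    out_h, out_w = in_h + top + bottom, in_w + left + right
+    for out_row in range(out_h):
+        for out_col in range(out_w):
+            row_pad = out_row < top or out_row >= out_h - bottom
+            col_pad = out_col < left or out_col >= out_w - right
+            if not (row_pad or col_pad):
+                # Interior output pixels come from (out_row - top, out_col - left).
+                assert 0 <= out_row - top < in_h and 0 <= out_col - left < in_w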
diff --git a/auto_LiRPA/operators/cut_ops.py b/auto_LiRPA/operators/cut_ops.py
new file mode 100644
index 0000000..e4337ae
--- /dev/null
+++ b/auto_LiRPA/operators/cut_ops.py
@@ -0,0 +1,588 @@
+""" Cut operators"""
+from .base import *
+from .clampmult import multiply_by_A_signs
+
+
+class CutModule:
+ # store under BoundedModule
+ def __init__(self, relu_nodes=[], general_beta=None, x_coeffs=None,
+ active_cuts=None, cut_bias=None):
+ # all dict, storing cut parameters for each start node
+ # {start node name: (2 (lA, uA), spec (out_c, out_h, out_w), batch, num_cuts)}
+ self.general_beta = general_beta
+ # {start node name: (# active cut constraints,)}
+ self.active_cuts = active_cuts
+
+ # all dict with tensor, storing coeffs for each relu layer, no grad
+ # coeffs: {relu layername: (num_cuts, flattened_nodes)}
+ self.relu_coeffs, self.arelu_coeffs, self.pre_coeffs = {}, {}, {}
+ for m in relu_nodes:
+ self.relu_coeffs[m.name] = self.arelu_coeffs[m.name] = self.pre_coeffs[m.name] = None
+
+ # single tensor, always the same, no grad
+ # bias: (num_cuts,)
+ self.cut_bias = cut_bias
+ # x_coeffs: (num_cuts, flattened input dims)
+ self.x_coeffs = x_coeffs
+
+ def use_patches(self, start_node):
+ # check if we are using patches mode for the start node
+ A = start_node.lA if start_node.lA is not None else start_node.uA
+ return type(A) is Patches
+
+ def select_active_general_beta(self, start_node, unstable_idx=None):
+        # if one constraint has nodes deeper than the start node, we do not count its effect for now
+ # self.general_beta[start_node.name]: (2(0 lower, 1 upper), spec (out_c, out_h, out_w/# fc nodes), batch, num_constrs)
+        # self.active_cuts[start_node.name]: a long() tensor with the indices of the
+        # constraints that should be applied at the current layer for the current start node
+ if self.general_beta[start_node.name].ndim == 4:
+ general_beta = self.general_beta[start_node.name][:, :, :, self.active_cuts[start_node.name]]
+ elif self.general_beta[start_node.name].ndim == 6:
+ general_beta = self.general_beta[start_node.name][:, :, :, :, :, self.active_cuts[start_node.name]]
+        else:
+            raise ValueError('general beta shape not supported!')
+ if unstable_idx is not None:
+ if self.use_patches(start_node):
+ general_beta = general_beta[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], :, :]
+ else:
+ # matrix mode
+ if general_beta.ndim == 6:
+ # conv layers general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)
+ _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape
+ general_beta = general_beta.view(2, -1, batch, num_constrs)
+ else:
+ # dense layers general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)
+ pass
+ general_beta = general_beta[:, unstable_idx]
+ else:
+ # unstable_idx is None
+ if general_beta.ndim == 6:
+ # flatten spec layer shape
+ _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape
+ general_beta = general_beta.view(2, -1, batch, num_constrs)
+ return general_beta
+
+ def general_beta_coeffs_mm(self, unstable_spec_beta, coeffs, A, current_layer_shape):
+ if type(A) is Patches:
+ # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA
+ # coeffs: (num_constrs, current_c, current_h, current_w)
+ # coeffs_unfolded: (num_constrs, out_h, out_w, in_c, H, W)
+ # current_layer_shape = x.lower.size()[1:]
+ coeffs_unfolded = inplace_unfold(coeffs.view(-1, *current_layer_shape), \
+ kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride)
+ # unstable_coeffs_unfolded: (num_constrs, unstable, in_c, H, W)
+ # A.unstable_idx is the unstable idx for spec layer
+ unstable_coeffs_unfolded = coeffs_unfolded[:, A.unstable_idx[1], A.unstable_idx[2], :, :, :]
+ # A.unstable_idx: unstable index on out_c, out_h and out_w
+ # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)
+ # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs)
+ # unstable_spec_beta = general_beta[:, A.unstable_idx[0],\
+ # A.unstable_idx[1], A.unstable_idx[2], :, :]
+ # beta_mm_coeffs_unfolded: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)
+ beta_mm_coeffs = torch.einsum('sihj,jiabc->sihabc', unstable_spec_beta, unstable_coeffs_unfolded)
+ else:
+ # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs)
+ # coeffs: (num_constrs, current flattened layer nodes)
+ # beta_mm_coeffs: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)
+ beta_mm_coeffs = torch.einsum('sihj,ja->siha', unstable_spec_beta, coeffs)
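+            # Shape sketch with assumed toy sizes: beta (2, unstable=5, batch=3,
+            # num_constrs=7) times coeffs (7, flat=16) gives (2, 5, 3, 16).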
+        assert beta_mm_coeffs[0].numel() == A.numel(), f"the shape of beta is not initialized correctly! {beta_mm_coeffs[0].shape} vs. {A.shape}"
+ return beta_mm_coeffs.reshape(2, *A.shape)
+
+ def general_beta_coeffs_addmm_to_A(self, lA, uA, general_beta, coeffs, current_layer_shape):
+ A = lA if lA is not None else uA
+ # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)
+ # coeffs: (num_constrs, current_c, current_h, current_w)
+ # beta_mm_coeffs[0] shape is the same as A
+ # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)
+ # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)
+ beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, coeffs, A, current_layer_shape)
+ assert beta_mm_coeffs[0].shape == A.shape
+ if type(A) is Patches:
+ # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA
+ # lA_patches: (unstable, batch, in_c, H, W)
+ if lA is not None:
+ lA = Patches(lA.patches - beta_mm_coeffs[0], A.stride, A.padding, \
+ A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)
+ if uA is not None:
+ uA = Patches(uA.patches + beta_mm_coeffs[1], A.stride, A.padding, \
+ A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)
+ else:
+ # dense layers
+ if lA is not None:
+ lA = lA - beta_mm_coeffs[0]
+ if uA is not None:
+ uA = uA + beta_mm_coeffs[1]
+ return lA, uA
+
+ def patch_trick(self, start_node, layer_name, A, current_layer_shape):
+        ######## A problem with patches mode for cut constraints ##########
+        # There are cases where a node appears in the constraint but is not selected by the patches for the output node.
+        # Trick: only count the small patches whose local sum coeffs_unfolded[ci][out_h, out_w, -1].sum() equals the full coeffs[ci].sum(),
+        # i.e., the patch covers all of the constraint's split nodes; for the rest, we force beta to 0 to disable the effect of these constraints.
+        # This only applies if the current layer uses patches mode; if the target layer uses patches but the current layer does not, we should not use it!
+        assert type(A) is Patches, "this trick only works for patches mode"
+ # unstable_spec_beta stores the current propagation, self.general_beta[start_node.name] selected with active_cuts, spec unstable
+ coeffs = 0
+ if layer_name != "input":
+ if self.relu_coeffs[layer_name] is not None:
+ coeffs = coeffs + self.relu_coeffs[layer_name]
+ if self.arelu_coeffs[layer_name] is not None:
+ coeffs = coeffs + self.arelu_coeffs[layer_name]
+ if self.pre_coeffs[layer_name] is not None:
+ coeffs = coeffs + self.pre_coeffs[layer_name]
+ else:
+ if self.x_coeffs is not None:
+ coeffs = coeffs + self.x_coeffs
+ coeffs_unfolded = inplace_unfold(coeffs.view(-1, *current_layer_shape), \
+ kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride)
+ num_constrs, out_h, out_w, in_c, H, W = coeffs_unfolded.shape
+ # make sure the small patch selected include all the nonzero coeffs
+ ####### NOTE: This check could be costly #######
+ patch_mask_on_beta = (coeffs_unfolded.reshape(num_constrs, out_h, out_w, -1).abs().sum(-1) == \
+ coeffs.reshape(num_constrs, -1).abs().sum(-1).reshape(num_constrs, 1, 1))
+ # patch_mask_on_beta: (out_h, out_w, num_constrs)
+ patch_mask_on_beta = patch_mask_on_beta.permute(1, 2, 0)
+ # 2(lower, upper), out_c, out_h, out_w, batch, num_constrs
+ patch_mask_on_beta = patch_mask_on_beta.reshape(1, 1, out_h, out_w, 1, num_constrs)
+ self.general_beta[start_node.name].data = self.general_beta[start_node.name].data * patch_mask_on_beta
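+        # Example (illustrative): if a constraint has nonzero coefficients at
+        # pixels (0, 0) and (6, 6) but a 3x3 patch only covers (0, 0), the
+        # patch-local abs-sum differs from the total abs-sum, so the mask is
+        # False and that beta is zeroed for this patch position.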
+
+ def relu_cut(self, start_node, layer_name, last_lA, last_uA, current_layer_shape, unstable_idx=None, batch_mask=None):
+ # propagate relu neuron in cut constraints through relu layer
+ # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately
+ relu_coeffs = self.relu_coeffs[layer_name]
+ active_cuts = self.active_cuts[start_node.name]
+        # active_cuts.size(0) == 0 means all constraints containing this layer have nodes in deeper layers
+ if relu_coeffs is None or active_cuts.size(0) == 0:
+ # do nothing
+ return last_lA, last_uA
+ assert start_node.name in self.general_beta
+ # select current relu layer general beta
+ general_beta = self.select_active_general_beta(start_node, unstable_idx)
+ relu_coeffs = relu_coeffs[active_cuts]
+ if batch_mask is not None:
+ general_beta = general_beta[:, :, batch_mask]
+ last_lA, last_uA = self.general_beta_coeffs_addmm_to_A(last_lA, last_uA, general_beta,
+ relu_coeffs, current_layer_shape)
+ return last_lA, last_uA
+
+ def pre_cut(self, start_node, layer_name, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None):
+ # propagate prerelu neuron in cut constraints through relu layer
+ # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately
+ pre_coeffs = self.pre_coeffs[layer_name]
+ active_cuts = self.active_cuts[start_node.name]
+        # active_cuts.size(0) == 0 means all constraints containing this layer have nodes in deeper layers
+ if pre_coeffs is None or active_cuts.size(0) == 0:
+ # do nothing
+ return lA, uA
+ general_beta = self.select_active_general_beta(start_node, unstable_idx)
+ pre_coeffs = pre_coeffs[active_cuts]
+ if batch_mask is not None:
+ general_beta = general_beta[:, :, batch_mask]
+ lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, pre_coeffs, current_layer_shape)
+ return lA, uA
+
+
+ @staticmethod
+ @torch.jit.script
+ def jit_arelu_lA(last_lA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d):
+ nu_hat_pos = last_lA.clamp(max=0.).abs()
+ tao = (-lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+ pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+ tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ new_upper_d = pi / (pi + tao + 1e-10)
+        # We need to customize the upper bound slope and lbias for (1) unstable relus and
+        # (2) relus that are used with the upper boundary relaxation.
+        # The original upper bound slope is u/(u-l), which equals pi/(pi+tao) when beta_mm_coeffs[0] is zero;
+        # now the upper bound slope becomes pi/(pi+tao), updated by beta_mm_coeffs[0].
+ unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(last_lA < 0)
+ # conv layer:
+ # upper_d: 1, batch, current_c, current_w, current_h
+ # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h
+ # dense layer:
+ # upper_d: 1, batch, current flattened nodes
+ # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes
+ new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+ upper_d * (1. - unstable_upper_bound_index.float())
+ return nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index
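+        # Summary of the relaxation above (illustrative): with
+        # nu = clamp(last_lA, max=0).abs() and cut term beta = beta_mm_coeffs[0],
+        # tao = clamp((-lower * nu - beta) / (upper - lower), 0, nu) and
+        # pi = clamp((upper * nu + beta) / (upper - lower), 0, nu); the relaxed
+        # upper slope pi / (pi + tao) reduces to upper / (upper - lower) when beta is zero.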
+
+ @staticmethod
+ @torch.jit.script
+ def jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, lbias, pi, tao):
+        # If there are no unstable neurons, the following bias should always be 0.
+        if unstable_or_cut_index.sum() > 0:
+            # update lbias with the new form, contributed only by unstable relus
+ uC = -upper.unsqueeze(0) * nu_hat_pos
+ lC = -lower.unsqueeze(0) * nu_hat_pos
+ # lbias: (spec unstable, batch, current flattened nodes) same as lA
+ lbias = (pi * lower.unsqueeze(0))
+
+ uC_mask = (beta_mm_coeffs[0] <= uC).to(lbias)
+ lC_mask = (beta_mm_coeffs[0] >= lC).to(lbias)
+ default_mask = ((1-uC_mask) * (1-lC_mask)).to(lbias)
+ lbias = - beta_mm_coeffs[0].to(lbias) * lC_mask + lbias * default_mask
+
+ # lbias[beta_mm_coeffs[0] <= uC] = 0.
+ # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)
+
+ # final lbias: (spec unstable, batch)
+ lbias = (lbias * unstable_or_cut_index.unsqueeze(0).float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)
+ return lbias
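+        # Piecewise view of the lower bias above (per neuron, illustrative):
+        #   beta <= -upper * nu  ->  0
+        #   beta >= -lower * nu  ->  -beta
+        #   otherwise            ->  pi * lower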
+
+ @staticmethod
+ @torch.jit.script
+ def jit_arelu_uA(last_uA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d):
+ nu_hat_pos = (-last_uA).clamp(max=0.).abs()
+ tao = (- lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+ pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+ tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ new_upper_d = pi / (pi + tao + 1e-10)
+
+ # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+ # unstable_or_cut_index = self.I.logical_or(self.arelu_coeffs.sum(0).view(self.I.shape) != 0)
+ unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(-last_uA < 0)
+ new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+ upper_d * (1. - unstable_upper_bound_index.float())
+ return nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index
+
+ @staticmethod
+ @torch.jit.script
+ def jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, ubias, pi, tao):
+ if unstable_or_cut_index.sum() > 0:
+ uC = -upper.unsqueeze(0) * nu_hat_pos
+ lC = -lower.unsqueeze(0) * nu_hat_pos
+ ubias = -(pi * lower.unsqueeze(0))
+
+ uC_mask = (beta_mm_coeffs[1] <= uC).to(ubias)
+ lC_mask = (beta_mm_coeffs[1] >= lC).to(ubias)
+ default_mask = ((1-uC_mask) * (1-lC_mask)).to(ubias)
+ ubias = beta_mm_coeffs[1].to(ubias) * lC_mask + ubias * default_mask
+
+ # ubias[beta_mm_coeffs[1] <= uC] = 0.
+ # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias)
+
+ ubias = (ubias * unstable_or_cut_index.unsqueeze(0).float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1)
+ return ubias
+
+
+ def arelu_cut(self, start_node, layer_name, last_lA, last_uA, lower_d, upper_d,
+ lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, patch_size,
+ current_layer_shape, unstable_idx=None, batch_mask=None):
+ # propagate integer var of relu neuron (arelu) in cut constraints through relu layer
+ arelu_coeffs = self.arelu_coeffs[layer_name]
+ active_cuts = self.active_cuts[start_node.name]
+        # active_cuts.size(0) == 0 means all constraints containing this layer have nodes in deeper layers
+ if arelu_coeffs is None or active_cuts.size(0) == 0:
+ # do regular propagation without cut
+ uA, ubias = _bound_oneside(last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)
+ lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, lower_b, upper_b, start_node, patch_size)
+ return lA, uA, lbias, ubias
+
+ # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)
+ general_beta = self.select_active_general_beta(start_node, unstable_idx)
+ # arelu_coeffs: (num_constrs, flattened current layer nodes)
+ arelu_coeffs = arelu_coeffs[active_cuts]
+ if batch_mask is not None:
+ general_beta = general_beta[:, :, batch_mask]
+ A = last_lA if last_lA is not None else last_uA
+ # beta_mm_coeffs[0] shape is the same as A
+ # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)
+ # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)
+ beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, arelu_coeffs, A, current_layer_shape)
+ # unstable_this_layer = torch.logical_and(x.lower < 0, x.upper > 0).unsqueeze(0)
+ # I is the unstable index in this relu layer: (batch, *layer shape)
+        # if there is a node in a cut constraint that is stable, we also need to count its effect
+ # self.arelu_coeffs: (num_constrs, flattened current layer)
+ unstable_or_cut_index = I.logical_or(arelu_coeffs.sum(0).view(I[0:1].shape) != 0)
+
+ if type(A) is Patches:
+ # patches mode, conv layer only
+ # x.lower (always regular shape): batch, current_c, current_h, current_w
+ # x_lower_unfold: unstable, batch, in_C, H, W (same as patches last_lA)
+ x_lower_unfold = _maybe_unfold(x.lower.unsqueeze(0), A)
+ x_upper_unfold = _maybe_unfold(x.upper.unsqueeze(0), A)
+            # subtracting lower from upper first and then unfolding to the patch size saves memory
+ x_upper_minus_lower_unfold = _maybe_unfold((x.upper - x.lower).unsqueeze(0), A)
+ ####### be careful with the unstable_this_layer and unstable_idx #######
+ # unstable_this_layer is the unstable index in this layer
+ # unstable_idx is the unstable index in spec layer
+ # unstable_this_layer: spec unstable, batch, in_C, H, W (same as patches last_lA)
+ # unstable_this_layer = torch.logical_and(x_lower_unfold < 0, x_upper_unfold > 0)
+ # unstable_this_layer = _maybe_unfold(self.I.unsqueeze(0), A)
+ unstable_or_cut_index = _maybe_unfold(unstable_or_cut_index.unsqueeze(0), A)
+ if last_lA is not None:
+ assert beta_mm_coeffs[0].shape == last_lA.shape, f"{beta_mm_coeffs[0].shape} != {last_lA.shape}"
+ # last_lA.patches, nu_hat_pos, tao, pi: (unstable, batch, in_c, H, W)
+ nu_hat_pos = last_lA.patches.clamp(max=0.).abs()
+ tao = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10))
+ pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10))
+ tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ new_upper_d = pi / (pi + tao + 1e-10)
+
+ assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+ # unstable_upper_bound_index: spec unstable, batch, in_C, H, W (same as patches last_lA)
+ unstable_upper_bound_index = unstable_or_cut_index.logical_and(last_lA.patches < 0)
+ # upper_d: (spec unstable, 1, in_C, H, W) (unfolded shape, same as patches last_lA)
+ new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+ upper_d * (1. - unstable_upper_bound_index.float())
+
+ if last_uA is None: uA, ubias = None, 0.
+ # lbias: unstable, batch
+ # lA: unstable, batch, in_C, H, W (same as patches last_lA)
+ lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size)
+
+ # if general_beta[0].sum()!=0: import pdb; pdb.set_trace()
+            # there are unstable relus in this layer
+ if unstable_or_cut_index.sum() > 0:
+ uC = -x_upper_unfold * nu_hat_pos
+ lC = -x_lower_unfold * nu_hat_pos
+ lbias = (pi * x_lower_unfold)
+ lbias[beta_mm_coeffs[0] <= uC] = 0.
+ lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)
+ # lbias: unstable, batch, in_C, H, W (same as patches last_lA) => lbias: (unstable, batch)
+ lbias = (lbias * unstable_or_cut_index.float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)
+
+ if last_uA is not None:
+ # get the upper bound
+ nu_hat_pos = (-last_uA.patches).clamp(max=0.).abs()
+ tao = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10)
+ pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10)
+ tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ new_upper_d = pi / (pi + tao + 1e-10)
+
+ assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+ unstable_upper_bound_index = unstable_or_cut_index.logical_and((-last_uA.patches) < 0)
+ new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+ upper_d * (1. - unstable_upper_bound_index.float())
+
+ uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)
+ if last_lA is None: lA, lbias = None, 0.
+
+ if unstable_or_cut_index.sum() > 0:
+ uC = -x_upper_unfold * nu_hat_pos
+ lC = -x_lower_unfold * nu_hat_pos
+ ubias = -(pi * x_lower_unfold)
+ ubias[beta_mm_coeffs[1] <= uC] = 0.
+ ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias)
+ # ubias: unstable, batch, in_C, H, W (same as patches last_uA) => ubias: (unstable, batch)
+ ubias = (ubias * unstable_or_cut_index.float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1)
+ else:
+ # dense
+ if last_lA is not None:
+ # #####################
+ # # C is nu_hat_pos
+ # # last_lA: (spec unstable, batch, current flattened nodes (current_c*current_h*current_w))
+ # nu_hat_pos = last_lA.clamp(max=0.).abs()
+ # # pi, tao: spec_unstable, batch, current layer shape (same as last_lA)
+ # tao = (-x.lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[0]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10)
+ # pi = (x.upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[0]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10)
+ # tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ # tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ # new_upper_d = pi / (pi + tao + 1e-10)
+
+ # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+ # # need to customize the upper bound slope and lbias for (1) unstable relus and
+ # # (2) relus that are used with upper boundary relaxation
+ # # original upper bound slope is u/(u-l) also equal to pi/(pi+tao) if no beta_mm_coeffs[0]
+ # # now the upper bound slope should be pi/(p+tao) updated with beta_mm_coeffs[0]
+ # unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(last_lA < 0)
+ # # conv layer:
+ # # upper_d: 1, batch, current_c, current_w, current_h
+ # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h
+ # # dense layer:
+ # # upper_d: 1, batch, current flattened nodes
+ # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes
+ # new_upper_d = new_upper_d * unstable_upper_bound_index.float() +\
+ # upper_d * (1. - unstable_upper_bound_index.float())
+ # #####################
+
+ nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_lA(last_lA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d)
+
+ if last_uA is None: uA, ubias = None, 0.
+ lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size)
+
+ # if unstable_or_cut_index.sum() == 0: assert (lbias == 0).all(), "lbias should be 0 if no unstable relus"
+
+ # #####################
+ # # if no unstable, following bias should always be 0
+ # if unstable_or_cut_index.sum() > 0:
+ # # update lbias with new form, only contribued by unstable relus
+ # uC = -x.upper.unsqueeze(0) * nu_hat_pos
+ # lC = -x.lower.unsqueeze(0) * nu_hat_pos
+ # # lbias: (spec unstable, batch, current flattened nodes) same as lA
+ # lbias = (pi * x.lower.unsqueeze(0))
+ # lbias[beta_mm_coeffs[0] <= uC] = 0.
+ # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)
+ # # final lbias: (spec unstable, batch)
+ # lbias = (lbias * unstable_or_cut_index.unsqueeze(0).float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)
+ # #####################
+ lbias = self.jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, x.lower, x.upper, lbias, pi, tao)
+
+ if last_uA is not None:
+ # # C is nu_hat_pos
+ # nu_hat_pos = (-last_uA).clamp(max=0.).abs()
+ # tao = (- x.lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[1]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10)
+ # pi = (x.upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[1]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10)
+ # tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+ # tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+ # new_upper_d = pi / (pi + tao + 1e-10)
+
+ # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+ # # unstable_or_cut_index = self.I.logical_or(self.arelu_coeffs.sum(0).view(self.I.shape) != 0)
+ # unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(-last_uA < 0)
+ # new_upper_d = new_upper_d * unstable_upper_bound_index.float() +\
+ # upper_d * (1. - unstable_upper_bound_index.float())
+ nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_uA(last_uA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d)
+
+                # one can test uA by optimizing -obj, which should give the same obj value
+ uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)
+ if last_lA is None: lA, lbias = None, 0.
+
+ # if unstable_or_cut_index.sum() == 0: assert ubias == 0, "ubias should be 0 if no unstable relus"
+
+ # if unstable_or_cut_index.sum() > 0:
+ # uC = -x.upper.unsqueeze(0) * nu_hat_pos
+ # lC = -x.lower.unsqueeze(0) * nu_hat_pos
+ # ubias = -(pi * x.lower.unsqueeze(0))
+ # ubias[beta_mm_coeffs[1] <= uC] = 0.
+ # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias)
+ # ubias = (ubias * unstable_or_cut_index.unsqueeze(0).float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1)
+
+ ubias = self.jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, x.lower, x.upper, ubias, pi, tao)
+
+ return lA, uA, lbias, ubias
+
+ def input_cut(self, start_node, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None):
+ # propagate input neuron in cut constraints through relu layer
+ active_cuts = self.active_cuts[start_node.name]
+ if self.x_coeffs is None or active_cuts.size(0) == 0:
+ return lA, uA
+
+ if type(lA) is Patches:
+ A = lA if lA is not None else uA
+ self.patch_trick(start_node, "input", A, current_layer_shape)
+
+ general_beta = self.select_active_general_beta(start_node, unstable_idx)
+ x_coeffs = self.x_coeffs[active_cuts]
+ if batch_mask is not None:
+ general_beta = general_beta[:, :, batch_mask]
+ # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)
+ # x_coeffs: (num_constrs, flattened input dims)
+ # beta_bias: (2(0 lower, 1 upper), batch, spec)
+ lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, x_coeffs, current_layer_shape)
+ return lA, uA
+
+ def bias_cut(self, start_node, lb, ub, unstable_idx=None, batch_mask=None):
+ active_cuts = self.active_cuts[start_node.name]
+ if self.cut_bias is None or active_cuts.size(0) == 0:
+ return lb, ub
+ bias_coeffs = self.cut_bias[active_cuts]
+ general_beta = self.select_active_general_beta(start_node, unstable_idx)
+ if batch_mask is not None:
+ general_beta = general_beta[:, :, batch_mask]
+ # add bias for the bias term of general cut
+ # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)
+ # bias_coeffs: (num_constrs,)
+ # beta_bias: (2(0 lower, 1 upper), batch, spec)
+ beta_bias = torch.einsum('sihj,j->shi', general_beta, bias_coeffs)
+ lb = lb + beta_bias[0] if lb is not None else None
+ ub = ub - beta_bias[1] if ub is not None else None
+ return lb, ub
+
+
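+# A minimal runnable sketch (toy sizes assumed; `torch` is in scope via the
+# star import from `.base`) of the contraction used in CutModule.bias_cut.
+# It is illustrative only and is not called anywhere in the library.
+def _bias_cut_shape_demo():
+    general_beta = torch.randn(2, 5, 3, 7)  # (lower/upper, spec, batch, num_constrs)
+    bias_coeffs = torch.randn(7)            # (num_constrs,)
+    beta_bias = torch.einsum('sihj,j->shi', general_beta, bias_coeffs)
+    assert beta_bias.shape == (2, 3, 5)     # (lower/upper, batch, spec)
+
+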
+# Choose upper or lower bounds based on the sign of last_A.
+# This is a copy from activation.py.
+def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg, start_node, patch_size):
+ if last_A is None:
+ return None, 0
+ if type(last_A) == Tensor:
+ A, bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg, contiguous=True)
+ return A, bias
+ elif type(last_A) == Patches:
+ # if last_A is not an identity matrix
+ assert last_A.identity == 0
+ if last_A.identity == 0:
+ # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.
+ # or (unstable_size, batch_size, in_c, H, W) when it is sparse.
+ patches = last_A.patches
+ patches_shape = patches.shape
+ if len(patches_shape) == 6:
+ patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ if d_pos is not None:
+ d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ if d_neg is not None:
+ d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ if b_pos is not None:
+ b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ if b_neg is not None:
+ b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:])
+ A_prod, bias = multiply_by_A_signs(patches, d_pos, d_neg, b_pos, b_neg)
+ # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.
+ # For sparse patches the return bias size is (unstable_size, batch).
+ # For regular patches the return bias size is (spec, batch, out_h, out_w).
+ if len(patches_shape) == 6:
+ A_prod = A_prod.view(*patches_shape)
+ # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.
+ if start_node is not None:
+ if last_A.unstable_idx is not None:
+ # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
+ patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]
+ else:
+ # Regular patches.
+ patch_size[start_node.name] = A_prod.size()
+ return Patches(A_prod, last_A.stride, last_A.padding, A_prod.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), bias
+
+
+# In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return.
+# This is a copy from activation.py.
+def _maybe_unfold(d_tensor, last_A):
+ # d_tensor (out_c, current_c, current_h, current_w): out_c shared the same alpha for spec layer
+ if d_tensor is None:
+ return None
+ # if mode == "matrix" or d_tensor is None or last_A is None:
+ if type(last_A) is not Patches or d_tensor is None or last_A is None:
+ return d_tensor
+    # Inputs are slopes with shape (spec, batch, input_c, input_h, input_w)
+ # Here spec is the same as out_c.
+ # assert d_tensor.ndim == 5
+ origin_d_shape = d_tensor.shape
+ if d_tensor.ndim == 6:
+ d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:])
+ d_shape = d_tensor.size()
+ # Reshape to 4-D tensor to unfold.
+ d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:])
+    # unfold the slope matrix as patches. Patch shape is (spec * batch, out_h, out_w, in_c, H, W).
+ d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding)
+ # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c.
+ d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])
+ if last_A.unstable_idx is not None:
+ if d_unfolded_r.size(0) == 1:
+ if len(last_A.unstable_idx) == 3:
+                # Broadcast the spec shape, so we only need to select the rest of the dimensions.
+ # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).
+ d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ elif len(last_A.unstable_idx) == 4:
+ # [spec, batch, output_h, output_w, input_c, H, W]
+ # to [output_h, output_w, batch, in_c, H, W]
+ d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]]
+ else:
+ raise NotImplementedError()
+ # output shape: (unstable_size, batch, in_c, H, W).
+ else:
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).
+ # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W).
+ if d_unfolded_r.ndim != last_A.patches.ndim:
+ d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4)
+ return d_unfolded_r
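+
+
+# A minimal runnable sketch of the unfolded layout produced by _maybe_unfold,
+# shown with the standard F.unfold (an assumption; inplace_unfold supports
+# extra options such as inserted zeros; `torch` and `F` come from the `.base`
+# star import): (spec * batch, C, H, W) -> (spec * batch, out_h, out_w, C, h, w).
+# Illustrative only; not called anywhere in the library.
+def _unfold_layout_demo():
+    d = torch.randn(6, 3, 8, 8)                          # (spec * batch, C, H, W)
+    unf = F.unfold(d, kernel_size=2, stride=2)           # (6, C * 2 * 2, out_h * out_w)
+    unf = unf.transpose(1, 2).reshape(6, 4, 4, 3, 2, 2)  # (6, out_h, out_w, C, h, w)
+    assert unf.shape == (6, 4, 4, 3, 2, 2)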
diff --git a/auto_LiRPA/operators/dropout.py b/auto_LiRPA/operators/dropout.py
index 97c803e..d4601a6 100644
--- a/auto_LiRPA/operators/dropout.py
+++ b/auto_LiRPA/operators/dropout.py
@@ -1,39 +1,69 @@
from .base import *
class BoundDropout(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.dropout = nn.Dropout(p=attr['ratio'])
- self.scale = 1 / (1 - attr['ratio'])
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ if 'ratio' in attr:
+ self.ratio = attr['ratio']
+ self.dynamic = False
+ else:
+ self.ratio = None
+ self.dynamic = True
+ self.clear()
+
+ def clear(self):
+ self.mask = None
+
+ def forward(self, *inputs):
+ x = inputs[0]
+ if not self.training:
+ return x
+ if self.dynamic:
+ # Inputs: data, ratio (optional), training_mode (optional)
+ # We assume ratio must exist in the inputs.
+ # We ignore training_mode, but will use self.training which can be
+ # changed after BoundedModule is built.
+ assert inputs[1].dtype == torch.float32
+ self.ratio = inputs[1]
+ if self.ratio >= 1:
+ raise ValueError('Ratio in dropout should be less than 1')
+        # Sample the mask on the same device as the input.
+        self.mask = torch.rand(x.shape, device=x.device) > self.ratio
+ return x * self.mask / (1 - self.ratio)
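+        # Example (illustrative): with ratio 0.5 and mask [1, 0, 1, 1], the
+        # surviving entries are doubled by the 1 / (1 - ratio) factor, so the
+        # expected value of the output matches the clean forward pass.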
- @Bound.save_io_shape
- def forward(self, x):
- res = self.dropout(x)
- self.mask = res == 0
- return res
+ def _check_forward(self):
+ """ If in the training mode, a forward pass should have been called."""
+ if self.training and self.mask is None:
+ raise RuntimeError('For a model with dropout in the training mode, '\
+ 'a clean forward pass must be called before bound computation')
- def bound_backward(self, last_lA, last_uA, x):
+ def bound_backward(self, last_lA, last_uA, *args):
+        empty_A = [(None, None)] * (len(args) - 1)
+ if not self.training:
+ return [(last_lA, last_uA), *empty_A], 0, 0
+ self._check_forward()
def _bound_oneside(last_A):
if last_A is None:
return None
- return torch.where(self.mask.unsqueeze(0), torch.tensor(0).to(last_A), last_A * self.scale)
+ return last_A * self.mask / (1 - self.ratio)
lA = _bound_oneside(last_lA)
uA = _bound_oneside(last_uA)
- return [(lA, uA)], 0, 0
+ return [(lA, uA), *empty_A], 0, 0
- def bound_forward(self, dim_in, x):
- assert (torch.min(self.mask) >= 0)
- lw = x.lw * self.mask.unsqueeze(1)
- lb = x.lb * self.mask
- uw = x.uw * self.mask.unsqueeze(1)
- ub = x.ub * self.mask
+ def bound_forward(self, dim_in, x, *args):
+ if not self.training:
+ return x
+ self._check_forward()
+ lw = x.lw * self.mask.unsqueeze(1) / (1 - self.ratio)
+ lb = x.lb * self.mask / (1 - self.ratio)
+ uw = x.uw * self.mask.unsqueeze(1) / (1 - self.ratio)
+ ub = x.ub * self.mask / (1 - self.ratio)
return LinearBound(lw, lb, uw, ub)
def interval_propagate(self, *v):
- h_L, h_U = v[0]
if not self.training:
- return h_L, h_U
- else:
- lower = torch.where(self.mask, torch.tensor(0).to(h_L), h_L * self.scale)
- upper = torch.where(self.mask, torch.tensor(0).to(h_U), h_U * self.scale)
- return lower, upper
\ No newline at end of file
+ return v[0]
+ self._check_forward()
+ h_L, h_U = v[0]
+ lower = h_L * self.mask / (1 - self.ratio)
+ upper = h_U * self.mask / (1 - self.ratio)
+ return lower, upper
\ No newline at end of file
diff --git a/auto_LiRPA/operators/dtype.py b/auto_LiRPA/operators/dtype.py
index 1ed293e..d5dfa38 100644
--- a/auto_LiRPA/operators/dtype.py
+++ b/auto_LiRPA/operators/dtype.py
@@ -1,8 +1,9 @@
from .base import *
+from ..utils import Patches
class BoundCast(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.to = attr['to']
self.data_types = [
None, torch.float, torch.uint8, torch.int8,
@@ -14,17 +15,25 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
assert self.type is not None
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
self.type_in = x.dtype
return x.to(self.type)
def bound_backward(self, last_lA, last_uA, x):
- lA = last_lA.to(self.type_in) if last_lA is not None else None
- uA = last_uA.to(self.type_in) if last_uA is not None else None
+        if type(last_lA) == Tensor or type(last_uA) == Tensor:
+            lA = last_lA.to(self.type_in) if last_lA is not None else None
+            uA = last_uA.to(self.type_in) if last_uA is not None else None
+        else:
+            # Both names must be defined even when one bound is not propagated.
+            lA = uA = None
+            if last_lA is not None:
+ lA = Patches(last_lA.patches.to(self.type_in), last_lA.stride, last_lA.padding, last_lA.shape, last_lA.identity, last_lA.unstable_idx, last_lA.output_shape)
+ if last_uA is not None:
+ uA = Patches(last_uA.patches.to(self.type_in), last_uA.stride, last_uA.padding, last_uA.shape, last_uA.identity, last_uA.unstable_idx, last_uA.output_shape)
return [(lA, uA)], 0, 0
def bound_forward(self, dim_in, x):
return LinearBound(
x.lw.to(self.type), x.lb.to(self.type),
x.uw.to(self.type), x.ub.to(self.type))
+
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(v[0])
diff --git a/auto_LiRPA/operators/leaf.py b/auto_LiRPA/operators/leaf.py
index 150e7f0..36c66b2 100644
--- a/auto_LiRPA/operators/leaf.py
+++ b/auto_LiRPA/operators/leaf.py
@@ -1,14 +1,17 @@
-""" Leaf nodes (indepedent nodes in the auto_LiRPA paper),
-including input, parameter, buffer, etc."""
+""" Leaf nodes (indepedent nodes in the auto_LiRPA paper).
+
+Including input, parameter, buffer, etc."""
from .base import *
class BoundInput(Bound):
- def __init__(self, input_name, name, ori_name, value, perturbation=None):
- super().__init__(input_name, name, ori_name)
+ def __init__(self, ori_name, value, perturbation=None, input_index=None):
+ super().__init__()
+ self.ori_name = ori_name
self.value = value
self.perturbation = perturbation
self.from_input = True
+ self.input_index = input_index
def __setattr__(self, key, value):
super().__setattr__(key, value)
@@ -103,7 +106,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
prefix (str): the prefix for parameters and buffers used in this
module
"""
- for name, param in self._parameters.items():
+ for param in self._parameters.values():
if param is not None:
if len(prefix.split('.')) == 2:
destination[self.ori_name] = param if keep_vars else param.detach()
@@ -111,7 +114,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
# change parameters' name to self.ori_name when calling state_dict()
destination[
'.'.join(prefix.split('.')[:-2]) + '.' + self.ori_name] = param if keep_vars else param.detach()
- for name, buf in self._buffers.items():
+ for buf in self._buffers.values():
if buf is not None:
if len(prefix.split('.')) == 2:
destination[self.ori_name] = buf if keep_vars else buf.detach()
@@ -120,7 +123,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[
'.'.join(prefix.split('.')[:-2]) + '.' + self.ori_name] = buf if keep_vars else buf.detach()
- @Bound.save_io_shape
def forward(self):
return self.value
@@ -144,17 +146,15 @@ def infer_batch_dim(self, batch_size, *x):
class BoundParams(BoundInput):
- def __init__(self, input_name, name, ori_name, value, perturbation=None):
- super().__init__(input_name, name, ori_name, None, perturbation)
+ def __init__(self, ori_name, value, perturbation=None):
+ super().__init__(ori_name, None, perturbation)
self.register_parameter('param', value)
self.from_input = False
self.initializing = False
- """Override register_parameter() hook to register only needed parameters."""
-
def register_parameter(self, name, param):
+ """Override register_parameter() hook to register only needed parameters."""
if name == 'param':
- # self._parameters[name] = param # cannot contain '.' in name, it will cause error when loading state_dict
return super().register_parameter(name, param)
else:
# Just register it as a normal property of class.
@@ -163,22 +163,20 @@ def register_parameter(self, name, param):
def init(self, initializing=False):
self.initializing = initializing
- @Bound.save_io_shape
def forward(self):
if self.initializing:
- return self.param_init
+ return self.param_init.requires_grad_(self.training)
else:
- return self.param
+ return self.param.requires_grad_(self.training)
def infer_batch_dim(self, batch_size, *x):
return -1
class BoundBuffers(BoundInput):
- def __init__(self, input_name, name, ori_name, value, perturbation=None):
- super().__init__(input_name, name, ori_name, None, perturbation)
+ def __init__(self, ori_name, value, perturbation=None):
+ super().__init__(ori_name, None, perturbation)
self.register_buffer('buffer', value.clone().detach())
- @Bound.save_io_shape
def forward(self):
- return self.buffer
\ No newline at end of file
+ return self.buffer
diff --git a/auto_LiRPA/operators/linear.py b/auto_LiRPA/operators/linear.py
index 7aa8ed9..8ce3a8b 100644
--- a/auto_LiRPA/operators/linear.py
+++ b/auto_LiRPA/operators/linear.py
@@ -1,9 +1,13 @@
""" Linear (possibly with weight perturbation) or Dot product layers """
from .base import *
from .bivariate import BoundMul
+from ..patches import Patches
+from .solver_utils import grb
+from torch import Tensor
+from ..patches import inplace_unfold
class BoundLinear(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
+ def __init__(self, attr, inputs, output_index, options):
# Gemm:
# A = A if transA == 0 else A.T
# B = B if transB == 0 else B.T
@@ -11,7 +15,7 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
# Y = alpha * np.dot(A, B) + beta * C
# return Y
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ super().__init__(attr, inputs, output_index, options)
# Defaults in ONNX
self.transA = 0
@@ -28,19 +32,18 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
"""Handle tranpose and linear coefficients."""
def _preprocess(self, a, b, c=None):
- if self.transA and isinstance(a, torch.Tensor):
+ if self.transA and isinstance(a, Tensor):
a = a.transpose(-2,-1)
if self.alpha != 1.0:
a = self.alpha * a
- if not self.transB and isinstance(b, torch.Tensor):
+ if not self.transB and isinstance(b, Tensor):
# our code assumes B is transposed (common case), so we transpose B only when it is not transposed in gemm.
- b = b.transpose(-2,-1)
+ b = b.transpose(-2, -1)
if c is not None:
if self.beta != 1.0:
c = self.beta * c
return a, b, c
- @Bound.save_io_shape
def forward(self, x, w, b=None):
x, w, b = self._preprocess(x, w, b)
self.input_shape = self.x_shape = x.shape
@@ -61,8 +64,8 @@ def onehot_mult(self, weight, bias, C, batch_size):
if C.index.ndim == 2:
# Shape is [spec, batch]
- index = C.index.transpose(0,1)
- coeffs = C.coeffs.transpose(0,1)
+ index = C.index.transpose(0, 1)
+ coeffs = C.coeffs.transpose(0, 1)
else:
index = C.index
coeffs = C.coeffs
@@ -90,13 +93,12 @@ def onehot_mult(self, weight, bias, C, batch_size):
new_bias = new_bias.transpose(0, 1)
return new_weight, new_bias
-
def bound_backward(self, last_lA, last_uA, *x):
assert len(x) == 2 or len(x) == 3
has_bias = len(x) == 3
# x[0]: input node, x[1]: weight, x[2]: bias
- input_lb = [xi.lower if hasattr(xi, 'lower') else None for xi in x]
- input_ub = [xi.upper if hasattr(xi, 'upper') else None for xi in x]
+ input_lb = [getattr(xi, 'lower', None) for xi in x]
+ input_ub = [getattr(xi, 'upper', None) for xi in x]
# transpose and scale each term if necessary.
input_lb = self._preprocess(*input_lb)
input_ub = self._preprocess(*input_ub)
@@ -106,31 +108,83 @@ def bound_backward(self, last_lA, last_uA, *x):
# Case #1: No weight/bias perturbation, only perturbation on input.
if not self.is_input_perturbed(1) and (not has_bias or not self.is_input_perturbed(2)):
+ weight = input_lb[1]
+ bias = input_lb[2] if has_bias else None
            # If last_lA and last_uA are identity matrices.
- if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC):
+ if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): # FIXME (12/28): we should check last_lA and last_uA separately. Same applies to the weight perturbed, bias perturbed settings.
# Use this layer's W as the next bound matrices. Duplicate the batch dimension. Other dimensions are kept 1.
# Not perturbed, so we can use either lower or upper.
- lA_x = uA_x = input_lb[1].unsqueeze(1).repeat([1, batch_size] + [1] * (input_lb[1].ndim - 1))
- # Bias will be directly added to output.
+ assert last_lA.shape == last_uA.shape
+ shape_others = prod(last_lA.shape[2:-1])
+ A_identity = torch.eye(shape_others).to(weight).view(shape_others, 1, 1, shape_others, 1)
+ assert last_lA.shape[0] == weight.size(0) * shape_others
+ w = weight.view(1, weight.size(0), *[1] * (len(last_lA.shape) - 2), weight.size(1))
+ w = w * A_identity
+
+ # expand the batch_size dim
+ lA_x = uA_x = w.view(last_lA.shape[0], 1, *last_lA.shape[2:-1], weight.size(1)).expand(last_lA.shape[0], *last_lA.shape[1:-1], weight.size(1))
if has_bias:
- lbias = ubias = input_lb[2].unsqueeze(1).repeat(1, batch_size)
+ lbias = ubias = bias.unsqueeze(1).repeat(1, batch_size)
elif isinstance(last_lA, OneHotC) or isinstance(last_uA, OneHotC):
# We need to select several rows from the weight matrix (its shape is output_size * input_size).
- lA_x, lbias = self.onehot_mult(input_lb[1], input_lb[2] if has_bias else None, last_lA, batch_size)
+ lA_x, lbias = self.onehot_mult(weight, bias, last_lA, batch_size)
if last_lA is last_uA:
uA_x = lA_x
ubias = lbias
else:
- uA_x, ubias = self.onehot_mult(input_lb[1], input_lb[2] if has_bias else None, last_uA, batch_size)
+ uA_x, ubias = self.onehot_mult(weight, bias, last_uA, batch_size)
else:
def _bound_oneside(last_A):
if last_A is None:
return None, 0
- # Just multiply this layer's weight into bound matrices, and produce biases.
- next_A = last_A.to(input_lb[1]).matmul(input_lb[1])
- sum_bias = (last_A.to(input_lb[2]).matmul(input_lb[2])
- if has_bias else 0.0)
- return next_A, sum_bias
+ if isinstance(last_A, torch.Tensor):
+ # Matrix mode.
+ # Just multiply this layer's weight into bound matrices, and produce biases.
+ next_A = last_A.to(weight).matmul(weight)
+ sum_bias = (last_A.to(bias).matmul(bias)
+ if has_bias else 0.0)
+ elif isinstance(last_A, Patches):
+ # Patches mode. After propagating through this layer, it will become a matrix.
+ # Reshape the weight matrix as a conv image.
+ # Weight was in (linear_output_shape, linear_input_shape)
+ # Reshape it to (linear_input_shape, c, h, w)
+ reshaped_weight = weight.transpose(0,1).view(-1, *last_A.input_shape[1:])
+ # After unfolding the shape is (linear_input_shape, output_h, output_w, in_c, patch_h, patch_w)
+ unfolded_weight = inplace_unfold(
+ reshaped_weight,
+ kernel_size=last_A.patches.shape[-2:],
+ stride=last_A.stride, padding=last_A.padding,
+ inserted_zeros=last_A.inserted_zeros,
+ output_padding=last_A.output_padding)
+ if has_bias:
+ # Do the same for the bias.
+ reshaped_bias = bias.view(*last_A.input_shape[1:]).unsqueeze(0)
+ # After unfolding the bias shape is (1, output_h, output_w, in_c, patch_h, patch_w)
+ unfolded_bias = inplace_unfold(reshaped_bias, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding)
+ if last_A.unstable_idx is not None:
+ # Reshape our weight to (output_h, output_w, 1, in_c, patch_h, patch_w, linear_input_shape), 1 is the inserted batch dim.
+ unfolded_weight_r = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(2)
+ # for sparse patches the shape is (unstable_size, batch, in_c, patch_h, patch_w). Batch size is 1 so no need to select here.
+ # We select in the (output_h, out_w) dimension.
+ selected_weight = unfolded_weight_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ next_A = torch.einsum('sbchw,sbchwi->sbi', last_A.patches, selected_weight)
+ if has_bias:
+ # Reshape our bias to (output_h, output_w, 1, in_c, patch_h, patch_w). We already have the batch dim.
+ unfolded_bias_r = unfolded_bias.permute(1, 2, 0, 3, 4, 5)
+ selected_bias = unfolded_bias_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ sum_bias = torch.einsum('sbchw,sbchw->sb', last_A.patches, selected_bias)
+ else:
+ # Reshape our weight to (1, 1, output_h, output_w, in_c, patch_h, patch_w, linear_input_shape), 1 is the spec and batch.
+ selected_weight = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(0).unsqueeze(0)
+ next_A_r = torch.einsum('sbpqchw,sbpqchwi->spqbi', last_A.patches, selected_weight)
+ # We return a matrix with flattened spec dimension (corresponding to out_c * out_h * out_w).
+ next_A = next_A_r.reshape(-1, next_A_r.size(-2), next_A_r.size(-1))
+ if has_bias:
+ # Reshape our bias to (1, 1, output_h, output_w, in_c, patch_h, patch_w)
+ selected_bias = unfolded_bias.unsqueeze(0)
+ sum_bias_r = torch.einsum('sbpqchw,sbpqchw->spqb', last_A.patches, selected_bias)
+ sum_bias = sum_bias_r.reshape(-1, sum_bias_r.size(-1))
+ return next_A, sum_bias if has_bias else 0.0
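+                # Shape note (illustrative): the patches branch above converts a
+                # Patches last_A into a dense matrix, e.g. sparse patches
+                # (unstable, batch, c, h, w) yield (unstable, batch, linear_input)
+                # and regular patches yield (out_c*out_h*out_w, batch, linear_input).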
lA_x, lbias = _bound_oneside(last_lA)
uA_x, ubias = _bound_oneside(last_uA)
@@ -176,30 +230,32 @@ def _bound_oneside(last_A):
def _reshape(self, x_l, x_u, y_l, y_u):
x_shape, y_shape = self.input_shape, self.y_shape
- # (x_1, x_2, ..., x_{n-1}, y_2, x_n)
+ # (x_1, x_2, ..., x_{n-1}, -1, x_n) # FIXME
x_l = x_l.unsqueeze(-2)
x_u = x_u.unsqueeze(-2)
+ # FIXME merge these two cases
if len(x_shape) == len(y_shape):
- # (x_1, x_2, ..., x_{n-1}, y_n, y_{n-1})
- shape = x_shape[:-1] + (y_shape[-1], y_shape[-2])
+ # (x_1, x_2, ..., -1, y_n, y_{n-1})
y_l = y_l.unsqueeze(-3)
y_u = y_u.unsqueeze(-3)
elif len(y_shape) == 2:
- # (x_1, x_2, ..., x_{n-1}, y_2, y_1)
- shape = x_shape[:-1] + y_shape[1:] + y_shape[:1]
+ # (x_1, x_2, ..., -1, y_2, y_1)
y_l = y_l.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3)
y_u = y_u.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3)
+ else:
+ raise ValueError(f'Unsupported shapes: x_shape {x_shape}, y_shape {y_shape}')
+
return x_l, x_u, y_l, y_u
def _relax(self, input_lb, input_ub):
return BoundMul.get_bound_mul(*self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1]))
+ # FIXME This is nonlinear. Move to `bivariate.py`.
def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y):
        # Note: x and y are not transposed or scaled, and we should avoid using them directly.
# Use input_lb and input_ub instead.
alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = self._relax(input_lb, input_ub)
-
alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0)
beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0)
x_shape, y_shape = input_lb[0].size(), input_lb[1].size()
@@ -216,34 +272,33 @@ def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y)
def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg, gamma_neg):
if last_A is None:
return None, None, 0
- if isinstance(last_A, eyeC):
- A_x = alpha_pos.squeeze(0).permute(1, 0, 2).repeat(1, last_A.shape[1], 1)
- A_y = beta_pos * torch.eye(last_A.shape[2], device=last_A.device) \
- .view((last_A.shape[2], 1, last_A.shape[2], 1))
- if len(dim_y) != 0:
- A_y = torch.sum(beta_pos, dim=dim_y)
- bias = gamma_pos.transpose(0, 1)
- else:
- # last_uA has size (batch, spec, output)
- last_A_pos = last_A.clamp(min=0).unsqueeze(-1)
- last_A_neg = last_A.clamp(max=0).unsqueeze(-1)
- # alpha_u has size (batch, spec, output, input)
- # uA_x has size (batch, spec, input).
- A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + \
- alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1)
- # beta_u has size (batch, spec, output, input)
- # uA_y is for weight matrix, with parameter size (output, input)
- # uA_y has size (batch, spec, output, input). This is an element-wise multiplication.
- A_y = last_A_pos * beta_pos + last_A_neg * beta_neg
- if len(dim_y) != 0:
- A_y = torch.sum(A_y, dim=dim_y)
- # last_uA has size (batch, spec, output)
- _last_A_pos = last_A_pos.reshape(last_A.shape[0], last_A.shape[1], -1)
- _last_A_neg = last_A_neg.reshape(last_A.shape[0], last_A.shape[1], -1)
- # gamma_u has size (batch, output, 1)
- # ubias has size (batch, spec, 1)
- bias = _last_A_pos.transpose(0, 1).matmul(gamma_pos).transpose(0, 1) + \
- _last_A_neg.transpose(0, 1).matmul(gamma_neg).transpose(0, 1)
+ if isinstance(last_A, eyeC): # FIXME (12/28): Handle the OneHotC case.
+            # FIXME: previous implementation is incorrect
+ # expanding eyeC for now
+ last_A = (torch.eye(last_A.shape[0], device=last_A.device)
+ .view(last_A.shape[0], 1, *last_A.shape[2:]).expand(last_A.shape))
+
+ # last_uA has size (batch, spec, output)
+ last_A_pos = last_A.clamp(min=0).unsqueeze(-1)
+ last_A_neg = last_A.clamp(max=0).unsqueeze(-1)
+ # alpha_u has size (batch, spec, output, input)
+ # uA_x has size (batch, spec, input).
+ A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + \
+ alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1)
+ # beta_u has size (batch, spec, output, input)
+ # uA_y is for weight matrix, with parameter size (output, input)
+ # uA_y has size (batch, spec, output, input). This is an element-wise multiplication.
+ A_y = last_A_pos * beta_pos + last_A_neg * beta_neg
+ if len(dim_y) != 0:
+ A_y = torch.sum(A_y, dim=dim_y)
+ # last_uA has size (batch, spec, output)
+ _last_A_pos = last_A_pos.reshape(last_A.shape[0], last_A.shape[1], -1)
+ _last_A_neg = last_A_neg.reshape(last_A.shape[0], last_A.shape[1], -1)
+ # gamma_u has size (batch, output, 1)
+ # ubias has size (batch, spec, 1)
+ bias = _last_A_pos.transpose(0, 1).matmul(gamma_pos).transpose(0, 1) + \
+ _last_A_neg.transpose(0, 1).matmul(gamma_neg).transpose(0, 1)
+
bias = bias.squeeze(-1)
return A_x, A_y, bias
@@ -254,24 +309,6 @@ def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg,
@staticmethod
def _propagate_Linf(x, w):
- if Interval.use_relative_bounds(x):
- if len(x.nominal.shape) == 2 and w.ndim == 3:
- nominal = torch.bmm(x.nominal.unsqueeze(1), w.transpose(-1, -2)).squeeze(1)
- lower_offset = (
- torch.bmm(x.lower_offset.unsqueeze(1), w.clamp(min=0).transpose(-1, -2)) +
- torch.bmm(x.upper_offset.unsqueeze(1), w.clamp(max=0).transpose(-1, -2))).squeeze(1)
- upper_offset = (
- torch.bmm(x.lower_offset.unsqueeze(1), w.clamp(max=0).transpose(-1, -2)) +
- torch.bmm(x.upper_offset.unsqueeze(1), w.clamp(min=0).transpose(-1, -2))).squeeze(1)
- else:
- nominal = x.nominal.matmul(w.transpose(-1, -2))
- lower_offset = (
- x.lower_offset.matmul(w.clamp(min=0).transpose(-1, -2)) +
- x.upper_offset.matmul(w.clamp(max=0).transpose(-1, -2)))
- upper_offset = (
- x.lower_offset.matmul(w.clamp(max=0).transpose(-1, -2)) +
- x.upper_offset.matmul(w.clamp(min=0).transpose(-1, -2)))
- return Interval(None, None, nominal, lower_offset, upper_offset)
h_L, h_U = x
mid = (h_L + h_U) / 2
diff = (h_U - h_L) / 2
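+        # What follows (illustrative summary): center = mid matmul w^T and
+        # deviation = diff matmul |w|^T, so the output interval is center +/- deviation.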
@@ -288,18 +325,11 @@ def interval_propagate(self, *v, C=None, w=None):
has_bias = self is not None and len(v) == 3
if self is not None:
# This will convert an Interval object to tuple. We need to add perturbation property later.
- if Interval.use_relative_bounds(v[0]):
- v_nominal = self._preprocess(v[0].nominal, v[1].nominal, v[2].nominal)
- v_lower_offset = self._preprocess(v[0].lower_offset, v[1].lower_offset, v[2].lower_offset)
- v_upper_offset = self._preprocess(v[0].upper_offset, v[1].upper_offset, v[2].upper_offset)
- v = [Interval(None, None, bounds[0], bounds[1], bounds[2])
- for bounds in zip(v_nominal, v_lower_offset, v_upper_offset)]
- else:
- v_lb, v_ub = zip(*v)
- v_lb = self._preprocess(*v_lb)
- v_ub = self._preprocess(*v_ub)
- # After preprocess the lower and upper bounds, we make them Intervals again.
- v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) for bounds in zip(v_lb, v_ub, v)]
+ v_lb, v_ub = zip(*v)
+ v_lb = self._preprocess(*v_lb)
+ v_ub = self._preprocess(*v_ub)
+ # After preprocess the lower and upper bounds, we make them Intervals again.
+ v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) for bounds in zip(v_lb, v_ub, v)]
if w is None and self is None:
# Use C as the weight, no bias.
w, lb, ub = C, torch.tensor(0., device=C.device), torch.tensor(0., device=C.device)
@@ -311,25 +341,16 @@ def interval_propagate(self, *v, C=None, w=None):
# C matrix merging not supported.
assert C is None
res = self.interval_propagate_with_weight(*v)
- if Interval.use_relative_bounds(res):
- if has_bias:
- raise NotImplementedError
- else:
- return res
+ l, u = res
+ if has_bias:
+ return l + v[2][0], u + v[2][1]
else:
- l, u = res
- if has_bias:
- return l + v[2][0], u + v[2][1]
- else:
- return l, u
+ return l, u
else:
- # Use weight
- if Interval.use_relative_bounds(v[1]):
- w = v[1].nominal
- else:
- w = v[1][0]
+ # Use weight
+ w = v[1][0]
if has_bias:
- lb, ub = (v[2].lower, v[2].upper) if Interval.use_relative_bounds(v[2]) else v[2]
+ lb, ub = v[2]
else:
lb = ub = 0.0
@@ -342,14 +363,7 @@ def interval_propagate(self, *v, C=None, w=None):
norm, eps = Interval.get_perturbation(v[0])[:2]
if norm == np.inf:
interval = BoundLinear._propagate_Linf(v[0], w)
- if isinstance(interval, Interval):
- b_center = (lb + ub) / 2
- interval.nominal += b_center
- interval.lower_offset += lb - b_center
- interval.upper_offset += ub - b_center
- return interval
- else:
- center, deviation = interval
+ center, deviation = interval
elif norm > 0:
# General Lp norm.
norm, eps = Interval.get_perturbation(v[0])
@@ -384,36 +398,7 @@ def interval_propagate(self, *v, C=None, w=None):
def interval_propagate_with_weight(self, *v):
input_norm, input_eps = Interval.get_perturbation(v[0])
- weight_norm, weight_eps = Interval.get_perturbation(v[1])
-
- if Interval.use_relative_bounds(*v):
- assert input_norm == weight_norm == np.inf
- assert self.opt_matmul == 'economic'
-
- x, y = v[0], v[1]
-
- nominal = x.nominal.matmul(y.nominal.transpose(-1, -2))
-
- matmul_offset = torch.matmul(
- torch.max(x.lower_offset.abs(), x.upper_offset.abs()),
- torch.max(y.upper_offset.abs(), y.lower_offset.abs()).transpose(-1, -2))
-
- lower_offset = (
- x.nominal.clamp(min=0).matmul(y.lower_offset.transpose(-1, -2)) +
- x.nominal.clamp(max=0).matmul(y.upper_offset.transpose(-1, -2)) +
- x.lower_offset.matmul(y.nominal.clamp(min=0).transpose(-1, -2)) +
- x.upper_offset.matmul(y.nominal.clamp(max=0).transpose(-1, -2)) - matmul_offset)
-
- upper_offset = (
- x.nominal.clamp(min=0).matmul(y.upper_offset.transpose(-1, -2)) +
- x.nominal.clamp(max=0).matmul(y.lower_offset.transpose(-1, -2)) +
- x.upper_offset.matmul(y.nominal.clamp(min=0).transpose(-1, -2)) +
- x.lower_offset.matmul(y.nominal.clamp(max=0).transpose(-1, -2)) + matmul_offset)
-
- return Interval(None, None, nominal, lower_offset, upper_offset)
-
- self.x_shape = v[0][0].shape
- self.y_shape = v[1][0].shape
+ weight_norm, weight_eps = Interval.get_perturbation(v[1])
if input_norm == np.inf and weight_norm == np.inf:
# A memory-efficient implementation without expanding all the elementary multiplications
@@ -424,9 +409,9 @@ def interval_propagate_with_weight(self, *v):
dx, dy = F.relu(x_u - x_l), F.relu(y_u - y_l)
base = x_l.matmul(y_l)
- mask_xp, mask_xn = (x_l > 0).float(), (x_u < 0).float()
+ mask_xp, mask_xn = (x_l > 0).to(x_l.dtype), (x_u < 0).to(x_u.dtype)
mask_xpn = 1 - mask_xp - mask_xn
- mask_yp, mask_yn = (y_l > 0).float(), (y_u < 0).float()
+ mask_yp, mask_yn = (y_l > 0).to(y_l.dtype), (y_u < 0).to(y_u.dtype)
mask_ypn = 1 - mask_yp - mask_yn
lower, upper = base.clone(), base.clone()
@@ -445,7 +430,7 @@ def interval_propagate_with_weight(self, *v):
x_l, x_u = v[0][0].unsqueeze(-2), v[0][1].unsqueeze(-2)
y_l, y_u = v[1][0].unsqueeze(-3), v[1][1].unsqueeze(-3)
# Reuse the multiplication bounds and sum over results.
- lower, upper = BoundMul.interval_propagate(*[(x_l, x_u), (y_l, y_u)])
+ lower, upper = BoundMul.interval_propagate_both_perturbed(*[(x_l, x_u), (y_l, y_u)])
lower, upper = torch.sum(lower, -1), torch.sum(upper, -1)
return lower, upper
@@ -465,9 +450,46 @@ def interval_propagate_with_weight(self, *v):
raise NotImplementedError(
"Unsupported perturbation combination: data={}, weight={}".format(input_norm, weight_norm))
+ @staticmethod
+ @torch.jit.script
+ def bound_forward_mul(x_lw: Tensor, x_lb: Tensor, x_uw: Tensor, x_ub: Tensor, w: Tensor):
+ w_pos, w_neg = w.clamp(min=0), w.clamp(max=0)
+ lw = x_lw.matmul(w_pos) + x_uw.matmul(w_neg)
+ uw = x_uw.matmul(w_pos) + x_lw.matmul(w_neg)
+ lb = x_lb.matmul(w_pos) + x_ub.matmul(w_neg)
+ ub = x_ub.matmul(w_pos) + x_lb.matmul(w_neg)
+ return lw, lb, uw, ub
+
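A minimal numeric check of the sign-splitting rule applied here (toy tensors; real callers pass `LinearBound` members with `w` already transposed):

    import torch

    x_l, x_u = torch.tensor([-1.0, 2.0]), torch.tensor([3.0, 5.0])
    w = torch.tensor([[2.0, -1.0], [0.5, 4.0]])      # (in, out), mixed signs
    w_pos, w_neg = w.clamp(min=0), w.clamp(max=0)
    lb = x_l.matmul(w_pos) + x_u.matmul(w_neg)       # lower bound uses x_l where w >= 0
    ub = x_u.matmul(w_pos) + x_l.matmul(w_neg)       # upper bound uses x_u where w >= 0
    for _ in range(100):                             # random points inside the box
        x = x_l + (x_u - x_l) * torch.rand(2)
        y = x.matmul(w)
        assert (y >= lb - 1e-5).all() and (y <= ub + 1e-5).all()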
+ # w: an optional argument which can be utilized by BoundMatMul
+ def bound_dynamic_forward(self, x, w=None, b=None, C=None, max_dim=None, offset=0):
+ assert not self.transA and self.alpha == 1.0 and self.transB and self.beta == 1.0
+ assert not self.is_input_perturbed(1)
+ assert not self.is_input_perturbed(2)
+
+ weight = w.lb
+ bias = b.lb if b is not None else None
+ if C is not None:
+ weight = C.to(weight).matmul(weight).transpose(-1, -2)
+ if bias is not None:
+ bias = C.to(bias).matmul(bias)
+ lb = x.lb.unsqueeze(1)
+ else:
+ weight = weight.t()
+ lb = x.lb
+
+ w_new = x.lw.matmul(weight)
+ b_new = lb.matmul(weight)
+ if C is not None:
+ b_new = b_new.squeeze(1)
+ if bias is not None:
+ b_new += bias
+
+ return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)
+
# w: an optional argument which can be utilized by BoundMatMul
def bound_forward(self, dim_in, x, w=None, b=None, C=None):
has_bias = b is not None
+ #FIXME _preprocess can only be applied to tensors so far but not linear bounds.
x, w, b = self._preprocess(x, w, b)
# Case #1: No weight/bias perturbation, only perturbation on input.
@@ -480,16 +502,15 @@ def bound_forward(self, dim_in, x, w=None, b=None, C=None):
w = C.to(w).matmul(w).transpose(-1, -2)
if b is not None:
b = C.to(b).matmul(b)
- w_pos, w_neg = w.clamp(min=0), w.clamp(max=0)
- lb = (x.lb.unsqueeze(1).matmul(w_pos) + x.ub.unsqueeze(1).matmul(w_neg)).squeeze(1)
- ub = (x.ub.unsqueeze(1).matmul(w_pos) + x.lb.unsqueeze(1).matmul(w_neg)).squeeze(1)
- else:
+ x_lb, x_ub = x.lb.unsqueeze(1), x.ub.unsqueeze(1)
+ else:
w = w.t()
- w_pos, w_neg = w.clamp(min=0), w.clamp(max=0)
- lb = x.lb.matmul(w_pos) + x.ub.matmul(w_neg)
- ub = x.ub.matmul(w_pos) + x.lb.matmul(w_neg)
- lw = x.lw.matmul(w_pos) + x.uw.matmul(w_neg)
- uw = x.uw.matmul(w_pos) + x.lw.matmul(w_neg)
+ x_lb, x_ub = x.lb, x.ub
+ lw, lb, uw, ub = BoundLinear.bound_forward_mul(x.lw, x_lb, x.uw, x_ub, w)
+
+ if C is not None:
+ lb, ub = lb.squeeze(1), ub.squeeze(1)
+
if b is not None:
lb += b
ub += b
@@ -524,7 +545,7 @@ def bound_forward_with_weight(self, dim_in, x, y):
y.lower.unsqueeze(-3),
y.upper.unsqueeze(-3),
)
- res_mul = BoundMul.bound_forward(dim_in, x_unsqueeze, y_unsqueeze)
+ res_mul = BoundMul.bound_forward_both_perturbed(dim_in, x_unsqueeze, y_unsqueeze)
return LinearBound(
res_mul.lw.sum(dim=-1) if res_mul.lw is not None else None,
res_mul.lb.sum(dim=-1),
@@ -532,16 +553,76 @@ def bound_forward_with_weight(self, dim_in, x, y):
res_mul.ub.sum(dim=-1)
)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ has_bias = self is not None and len(v) == 3
+ # e.g., last layer gurobi vars (1024,)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1024,)
+ pre_layer_shape = gvars_array.shape
+ # this layer shape (100,)
+ # if last layer, this layer shape (9,) instead of (10,)!!!
+ this_layer_shape = self.lower.squeeze(0).shape
+ out_lbs = self.lower.squeeze(0).detach().cpu().numpy() if self.lower is not None else None
+ out_ubs = self.upper.squeeze(0).detach().cpu().numpy() if self.upper is not None else None
+
+ # current layer weight (100, 1024)
+ this_layer_weight = v[1]
+ #### make sure this is correct for per-label operations
+ if C is not None:
+ # merge specification C into last layer weights
+ # only last layer has C not None
+ this_layer_weight = C.squeeze(0).mm(this_layer_weight)
+ # if last layer, this layer weight (9,100) instead of (10,100)!!!
+ this_layer_weight = this_layer_weight.detach().cpu().numpy()
+
+ this_layer_bias = None
+ if has_bias:
+ # current layer bias (100,)
+ this_layer_bias = v[2]
+ if C is not None:
+ this_layer_bias = C.squeeze(0).mm(this_layer_bias.unsqueeze(-1)).view(-1)
+ # if last layer, this layer bias (9,) instead of (10,)!!!
+ this_layer_bias = this_layer_bias.detach().cpu().numpy()
+
+ new_layer_gurobi_vars = []
+
+ for neuron_idx in range(this_layer_shape[0]):
+ out_lb = out_lbs[neuron_idx] if out_lbs is not None else -float('inf')
+ out_ub = out_ubs[neuron_idx] if out_ubs is not None else float('inf')
+
+ lin_expr = 0
+ if has_bias:
+ lin_expr = this_layer_bias[neuron_idx].item()
+ coeffs = this_layer_weight[neuron_idx, :]
+
+ if solver_pkg == 'gurobi':
+ lin_expr += grb.LinExpr(coeffs, v[0])
+ else:
+ # FIXME (01/12/22): This is slow, must be fixed using addRow() or similar.
+ for i in range(len(coeffs)):
+ try:
+ lin_expr += coeffs[i] * v[0][i]
+ except TypeError:
+ lin_expr += coeffs[i] * v[0][i].var
+
+ var = model.addVar(lb=out_lb, ub=out_ub, obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')
+ new_layer_gurobi_vars.append(var)
+
+ self.solver_vars = new_layer_gurobi_vars
+ model.update()
+
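For reference, a self-contained sketch of the per-neuron encoding constructed above, assuming gurobipy is installed (weights, bounds, and variable names here are illustrative):

    import gurobipy as grb

    model = grb.Model()
    model.Params.OutputFlag = 0
    xs = [model.addVar(lb=-1.0, ub=1.0, name=f'x_{i}') for i in range(3)]
    w, b = [2.0, -1.0, 0.5], 0.25

    lin_expr = b + grb.LinExpr(w, xs)                # bias plus weighted sum
    y = model.addVar(lb=-grb.GRB.INFINITY, ub=grb.GRB.INFINITY, name='y')
    model.addConstr(lin_expr == y, name='y_eq')

    model.setObjective(y, grb.GRB.MINIMIZE)          # concrete lower bound on y
    model.optimize()
    print(model.objVal)                              # -3.25 = 0.25 - (2 + 1 + 0.5)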
class BoundMatMul(BoundLinear):
# Reuse most functions from BoundLinear.
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.transA = 0
- self.transB = 1 # MatMul assumes B is transposed.
- self.nonlinear = True
+ self.transB = 0
+ self.requires_input_bounds = [0, 1]
- @Bound.save_io_shape
def forward(self, x, y):
self.x_shape = x.shape
self.y_shape = y.shape
@@ -550,23 +631,14 @@ def forward(self, x, y):
return x.matmul(y)
def interval_propagate(self, *v):
- w_l = v[1][0].transpose(-1, -2)
- w_u = v[1][1].transpose(-1, -2)
- lower, upper = super().interval_propagate(v[0], (w_l, w_u))
- return lower, upper
+ lower, upper = super().interval_propagate(*v)
+ return lower, upper
def bound_backward(self, last_lA, last_uA, *x):
assert len(x) == 2
- # BoundLinear has W transposed.
- x[1].lower = x[1].lower.transpose(-1, -2)
- x[1].upper = x[1].upper.transpose(-1, -2)
results = super().bound_backward(last_lA, last_uA, *x)
- # Transpose input back.
- x[1].lower = x[1].lower.transpose(-1, -2)
- x[1].upper = x[1].upper.transpose(-1, -2)
lA_y = results[0][1][0].transpose(-1, -2) if results[0][1][0] is not None else None
uA_y = results[0][1][1].transpose(-1, -2) if results[0][1][1] is not None else None
- # Transpose result on A.
return [results[0][0], (lA_y, uA_y), results[0][2]], results[1], results[2]
def bound_forward(self, dim_in, x, y):
@@ -579,37 +651,62 @@ def bound_forward(self, dim_in, x, y):
y.upper.transpose(-1, -2) if y.upper is not None else None
))
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x)
-
class BoundNeg(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, x):
return -x
def bound_backward(self, last_lA, last_uA, x):
- return [(-last_lA if last_lA is not None else None,
+ if type(last_lA) == Tensor or type(last_uA) == Tensor:
+ return [(-last_lA if last_lA is not None else None,
-last_uA if last_uA is not None else None)], 0, 0
+ elif type(last_lA) == Patches or type(last_uA) == Patches:
+ if last_lA is not None:
+ lA = Patches(-last_lA.patches, last_lA.stride, last_lA.padding, last_lA.shape, unstable_idx=last_lA.unstable_idx, output_shape=last_lA.output_shape)
+ else:
+ lA = None
+
+ if last_uA is not None:
+ uA = Patches(-last_uA.patches, last_uA.stride, last_uA.padding, last_uA.shape, unstable_idx=last_uA.unstable_idx, output_shape=last_uA.output_shape)
+ else:
+ uA = None
+ return [(lA, uA)], 0, 0
+ else:
+ raise NotImplementedError
def bound_forward(self, dim_in, x):
return LinearBound(-x.uw, -x.ub, -x.lw, -x.lb)
def interval_propagate(self, *v):
- return -v[0][1], -v[0][0]
+ return -v[0][1], -v[0][0]
+
class BoundCumSum(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x, axis):
self.axis = axis
return torch.cumsum(x, axis)
def infer_batch_dim(self, batch_size, *x):
assert self.axis != x[0]
- return x[0]
\ No newline at end of file
+ return x[0]
+
+
+class BoundIdentity(Bound):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.use_default_ibp = True
+
+ def forward(self, x):
+ return x
+
+ def bound_backward(self, last_lA, last_uA, x):
+ return [(last_lA, last_uA)], 0, 0
+
+ def bound_forward(self, dim_in, x):
+ return x
diff --git a/auto_LiRPA/operators/logical.py b/auto_LiRPA/operators/logical.py
index 869a50a..179a7ac 100644
--- a/auto_LiRPA/operators/logical.py
+++ b/auto_LiRPA/operators/logical.py
@@ -3,24 +3,14 @@
class BoundWhere(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, condition, x, y):
return torch.where(condition.to(torch.bool), x, y)
def interval_propagate(self, *v):
assert not self.is_input_perturbed(0)
-
- if Interval.use_relative_bounds(*v):
- return Interval(
- None, None,
- self.forward(v[0].nominal, v[1].nominal, v[2].nominal),
- self.forward(v[0].nominal, v[1].lower_offset, v[2].lower_offset),
- self.forward(v[0].nominal, v[1].upper_offset, v[2].upper_offset)
- )
-
condition = v[0][0]
return tuple([torch.where(condition, v[1][j], v[2][j]) for j in range(2)])
@@ -42,14 +32,9 @@ def _bound_oneside(last_A):
return [(None, None), (lA_x, uA_x), (lA_y, uA_y)], 0, 0
- def infer_batch_dim(self, batch_size, *x):
- return BoundMul.infer_batch_dim(batch_size, *x[1:])
-
-
class BoundNot(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, x):
return x.logical_not()
\ No newline at end of file
diff --git a/auto_LiRPA/operators/nonlinear.py b/auto_LiRPA/operators/nonlinear.py
new file mode 100644
index 0000000..4f6bb5c
--- /dev/null
+++ b/auto_LiRPA/operators/nonlinear.py
@@ -0,0 +1,668 @@
+"""Unary nonlinearities other than activation functions."""
+import math
+import torch
+from .activation import BoundActivation, BoundTanh
+from .base import epsilon, LinearBound
+
+class BoundSin(BoundActivation):
+ # Lookup tables shared by all BoundSin classes.
+ xl_lower_tb = None
+ xl_upper_tb = None
+ xu_lower_tb = None
+ xu_upper_tb = None
+ func, d_func = torch.sin, torch.cos
+ n_table_entries = 1001
+
+ @staticmethod
+ def n_crossing(start, end, s):
+ """Check how many times we will encounter value s + k*2*pi within start and end for any integer k."""
+ dtype = start.dtype
+ cycles = torch.floor((end - start) / (2 * math.pi)) # Number of 2pi cycles.
+ # Move s and end to the same 2 * pi cycle as start.
+ dist = torch.floor((s - start) / (2 * math.pi))
+ real_s = s - dist * 2 * math.pi
+ real_end = end - cycles * 2 * math.pi
+ # assert (real_end >= start - 2 ** (-20)).all()
+ return (real_s >= start).to(dtype) * (real_s <= real_end).to(dtype) + cycles
+
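Restated standalone for illustration, `n_crossing` counts how many of the values s + 2*k*pi fall inside [start, end]:

    import math, torch

    def n_crossing(start, end, s):
        cycles = torch.floor((end - start) / (2 * math.pi))   # whole 2*pi cycles
        dist = torch.floor((s - start) / (2 * math.pi))
        real_s = s - dist * 2 * math.pi                       # move s near start
        real_end = end - cycles * 2 * math.pi
        return (real_s >= start).float() * (real_s <= real_end).float() + cycles

    t = lambda v: torch.tensor(v)
    print(n_crossing(t(0.0), t(math.pi), t(math.pi / 2)))      # 1 crossing
    print(n_crossing(t(0.0), t(4 * math.pi), t(math.pi / 2)))  # 2 crossings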
+ @staticmethod
+ def get_intersection(start, end, c, theta=0.):
+ """Get the number of intersections between y = sin(x + theta) and y = c between start and end."""
+ # Use arcsine to find the first 2 intersections.
+ crossing1 = torch.arcsin(c) - theta
+ crossing2 = math.pi - crossing1 - 2 * theta # Problematic at exact 1/2 pi, but ok in our case (happens only when lb=ub).
+ return BoundSin.n_crossing(start, end, crossing1) + BoundSin.n_crossing(start, end, crossing2)
+
+ @staticmethod
+ def get_bounding_slope(xl, xu, c, theta=0.):
+ """Find the point between xl and xu such that the tangent line at that point is a lower/upper bound."""
+ dtype = xl.dtype
+ crossing1 = torch.arcsin(c) - theta # output is [-0.5 pi, 0.5 pi] - theta. For cosine, theta=0.5 pi and crossing point is between -pi to 0.
+ crossing2 = math.pi - crossing1 - 2 * theta # output is [0.5pi, 1.5pi] - theta. For cosine, it is between 0 and pi.
+ # Find the crossing point between xl and xu.
+ # First see how far xl is from the [-0.5pi, 1.5pi] range for sine or the [-pi, pi] range for cosine.
+ cycles1 = torch.floor((crossing1 - xl) / (2 * math.pi)) * 2 * math.pi
+ # Move the two crossing points to the same cycle as xl.
+ crossing1_moved = crossing1 - cycles1
+ cycles2 = torch.floor((crossing2 - xl) / (2 * math.pi)) * 2 * math.pi
+ crossing2_moved = crossing2 - cycles2
+ # Then check which crossing point is the actual tangent point between xl and xu.
+ crossing1_used = (crossing1_moved >= xl).to(dtype) * (crossing1_moved <= xu).to(dtype)
+ crossing2_used = (crossing2_moved >= xl).to(dtype) * (crossing2_moved <= xu).to(dtype)
+ crossing_point = crossing1_used * crossing1_moved + crossing2_used * crossing2_moved
+ # print(f'c1={crossing1.item():.05f}, c2={crossing2.item():.05f}, cy1={cycles1.item():.05f}, cy2={cycles2.item():.05f}, c1m={crossing1_moved.item():.05f}, c2m={crossing2_moved.item():.05f}, u1={crossing1_used.item()}, u2={crossing2_used.item()}, xl={xl.item():.05f}, xu={xu.item():.05f}')
+ return crossing_point
+
+ @staticmethod
+ def check_bound(tangent_point, x):
+ """Check whether the tangent line at tangent_point is a valid lower/upper bound for x."""
+ # evaluate the value of the tangent line at x and see it is >= 0 or <=0.
+ d = BoundSin.d_func(tangent_point)
+ val = d * (x - tangent_point) + BoundSin.func(tangent_point)
+ # We want a positive margin when finding a lower line, but as close to 0 as possible.
+ # We want a negative margin when finding an upper line, but as close to 0 as possible.
+ margin = BoundSin.func(x) - val
+ return margin
+
+ @staticmethod
+ @torch.no_grad()
+ def get_lower_left_bound(xl, steps=20):
+ """Get a global lower bound given lower bound on x. Return slope and intercept."""
+ dtype = xl.dtype
+ # Constrain xl into the -0.5 pi to 1.5 pi region.
+ cycles = torch.floor((xl + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+ xl = xl - cycles
+ use_tangent_line = (xl >= math.pi).to(dtype)
+ # Case 1: xl >= pi, the tangent line at xl is the only possible lower bound.
+ case1_d = BoundSin.d_func(xl)
+ case1_b = BoundSin.func(xl) - case1_d * (xl + cycles)
+ # Case 2: Binary search needed. Test tangent points in [pi, 1.5*pi]; the tangent point must lie in this region.
+ left = math.pi * torch.ones_like(xl)
+ # The right end guarantees the margin > 0 because it is basically an IBP lower bound (-1).
+ right = (1.5 * math.pi) * torch.ones_like(xl)
+ last_right = right.clone()
+ for i in range(steps):
+ mid = (left + right) / 2.
+ margin = BoundSin.check_bound(mid, xl)
+ pos_mask = (margin > 0).to(dtype) # We want the margin > 0, but as small as possible.
+ neg_mask = 1.0 - pos_mask
+ right = mid * pos_mask + right * neg_mask # We have positive margin, reduce right hand side.
+ last_right = mid * pos_mask + last_right * neg_mask # Always sound, since the margin is positive.
+ left = mid * neg_mask + left * pos_mask
+ case2_d = BoundSin.d_func(last_right)
+ case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles)
+ d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line)
+ b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line)
+ # Return slope and bias.
+ return [d, b]
+
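A scalar sketch of the bisection above: search [pi, 1.5*pi] for a tangent point whose tangent line soundly lower-bounds sin at a given xl (the masked tensor logic reduces to a plain if/else for scalars):

    import math, torch

    xl = torch.tensor(0.3)                   # somewhere in [-0.5*pi, pi)
    left = torch.tensor(math.pi)
    right = torch.tensor(1.5 * math.pi)      # tangent there is y = -1, always sound
    last_right = right.clone()
    for _ in range(20):
        mid = (left + right) / 2
        # margin = sin(xl) minus the tangent line's value at xl (cf. check_bound).
        margin = torch.sin(xl) - (torch.cos(mid) * (xl - mid) + torch.sin(mid))
        if margin > 0:                       # sound lower line: tighten from the right
            right = last_right = mid
        else:                                # unsound: move the left end up
            left = mid
    d = torch.cos(last_right)
    b = torch.sin(last_right) - d * last_right
    assert torch.sin(xl) >= d * xl + b       # the line lower-bounds sin at xl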
+ @staticmethod
+ @torch.no_grad()
+ def get_upper_left_bound(xl, steps=20):
+ """Get a global upper bound given lower bound on x. Return slope and intercept."""
+ dtype = xl.dtype
+ # Constrain xl into the -0.5 pi to 1.5 pi region.
+ cycles = torch.floor((xl - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+ xl = xl - cycles
+ use_tangent_line = (xl >= 2.0 * math.pi).to(dtype)
+ # Case 1: xl >= 2*pi, the tangent line at xl is the only possible upper bound.
+ case1_d = BoundSin.d_func(xl)
+ case1_b = BoundSin.func(xl) - case1_d * (xl + cycles)
+ # Case 2: Binary search needed. Test tangent points in [2*pi, 2.5*pi]; the tangent point must lie in this region.
+ left = (2.0 * math.pi) * torch.ones_like(xl)
+ # The right end guarantees the margin < 0 because it is basically an IBP upper bound (+1).
+ right = (2.5 * math.pi) * torch.ones_like(xl)
+ last_right = right.clone()
+ for i in range(steps):
+ mid = (left + right) / 2.
+ margin = BoundSin.check_bound(mid, xl)
+ pos_mask = (margin > 0).to(dtype) # We want the margin < 0, but as close to 0 as possible.
+ neg_mask = 1.0 - pos_mask
+ right = mid * neg_mask + right * pos_mask # We have negative margin (a sound upper line), reduce the right end.
+ last_right = mid * neg_mask + last_right * pos_mask # Always sound, since the margin is negative.
+ left = mid * pos_mask + left * neg_mask
+ case2_d = BoundSin.d_func(last_right)
+ case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles)
+ d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line)
+ b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line)
+ # Return slope and bias.
+ return [d, b]
+
+ @staticmethod
+ @torch.no_grad()
+ def get_lower_right_bound(xu, steps=20):
+ """Get a global lower bound given upper bound on x. Return slope and intercept."""
+ # Constrain xu into the -0.5 pi to 1.5 pi region.
+ cycles = torch.floor((xu + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+ xu = xu - cycles
+ d, _ = BoundSin.get_lower_left_bound(math.pi - xu, steps)
+ return [-d, BoundSin.func(xu) + d * (xu + cycles)]
+
+ @staticmethod
+ @torch.no_grad()
+ def get_upper_right_bound(xu, steps=20):
+ """Get a global upper bound given upper bound on x. Return slope and intercept."""
+ # Constrain xu into the 0.5 pi to 2.5 pi region.
+ cycles = torch.floor((xu - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+ xu = xu - cycles
+ d, _ = BoundSin.get_upper_left_bound(3 * math.pi - xu, steps)
+ return [-d, BoundSin.func(xu) + d * (xu + cycles)]
+
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ # Bound limits used by IBP.
+ self.max_point = math.pi / 2
+ self.min_point = math.pi * 3 / 2
+
+ self.all_table_x = torch.linspace(0, 2 * math.pi, BoundSin.n_table_entries, device=self.device)
+ if BoundSin.xl_lower_tb is None:
+ # Generate look-up tables.
+ BoundSin.xl_lower_tb = BoundSin.get_lower_left_bound(self.all_table_x)
+ BoundSin.xl_upper_tb = BoundSin.get_upper_left_bound(self.all_table_x)
+ BoundSin.xu_lower_tb = BoundSin.get_lower_right_bound(self.all_table_x)
+ BoundSin.xu_upper_tb = BoundSin.get_upper_right_bound(self.all_table_x)
+ BoundSin.xl_lower_tb[0], BoundSin.xl_lower_tb[1] = BoundSin.xl_lower_tb[0].to(self.device), BoundSin.xl_lower_tb[1].to(self.device)
+ BoundSin.xl_upper_tb[0], BoundSin.xl_upper_tb[1] = BoundSin.xl_upper_tb[0].to(self.device), BoundSin.xl_upper_tb[1].to(self.device)
+ BoundSin.xu_lower_tb[0], BoundSin.xu_lower_tb[1] = BoundSin.xu_lower_tb[0].to(self.device), BoundSin.xu_lower_tb[1].to(self.device)
+ BoundSin.xu_upper_tb[0], BoundSin.xu_upper_tb[1] = BoundSin.xu_upper_tb[0].to(self.device), BoundSin.xu_upper_tb[1].to(self.device)
+
+ @staticmethod
+ def interpoloate(x, lower_x, upper_x, lower_y, upper_y):
+ # x = torch.clamp(x, min=lower_x, max=upper_x) # For pytorch >= 1.11
+ x = torch.max(torch.min(x, upper_x), lower_x)
+ ratio = (x - lower_x) / (upper_x - lower_x + 1e-10)
+ return lower_y * (1. - ratio) + upper_y * ratio
+
+ def get_bound_tb(self, tb, x):
+ """Find lower or upper bounds from lookup table."""
+ step = 2 * math.pi / (BoundSin.n_table_entries - 1)
+ # Move to 0 to 2 pi region.
+ cycles = torch.floor(x / (2 * math.pi)) * (2 * math.pi)
+ x = torch.clamp(x - cycles, min=0, max=2 * math.pi)
+ # Find the index within the lookup table over [0, 2*pi].
+ indices = x.div(step).long()
+ # Interpolate between the nearest d and b. This has better differentiability.
+ # Another option is to always take the lower/upper side (better soundness).
+ upper_indices = torch.clamp(indices + 1, max=BoundSin.n_table_entries-1)
+ lower_x = self.all_table_x[indices]
+ upper_x = self.all_table_x[upper_indices]
+ # print(indices.item(), lower_x.item(), upper_x.item(), tb[0][indices].item(), tb[0][upper_indices].item())
+ d = self.interpoloate(x, lower_x, upper_x, tb[0][indices], tb[0][upper_indices])
+ b = self.interpoloate(x, lower_x, upper_x, tb[1][indices], tb[1][upper_indices])
+ return d, b - d * cycles
+
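A sketch of the table lookup and interpolation above, using sin values as a stand-in for the (d, b) tables and ignoring the `- d * cycles` period correction applied to the bias in the real code:

    import math, torch

    n = 1001
    table_x = torch.linspace(0, 2 * math.pi, n)
    table_y = torch.sin(table_x)                   # stand-in for a d or b table

    x = torch.tensor(7.0)
    cycles = torch.floor(x / (2 * math.pi)) * (2 * math.pi)
    x0 = torch.clamp(x - cycles, min=0, max=2 * math.pi)   # move into the table period
    step = 2 * math.pi / (n - 1)
    i = x0.div(step).long()
    j = torch.clamp(i + 1, max=n - 1)
    ratio = (x0 - table_x[i]) / (table_x[j] - table_x[i] + 1e-10)
    y = table_y[i] * (1 - ratio) + table_y[j] * ratio
    print(float(y), math.sin(7.0))                 # nearly identical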
+ def forward(self, x):
+ return torch.sin(x)
+
+ def interval_propagate(self, *v):
+ # Check if a point is in [l, u], considering the 2pi period
+ def check_crossing(ll, uu, point):
+ return ((((uu - point) / (2 * math.pi)).floor() - ((ll - point) / (2 * math.pi)).floor()) > 0).to(h_Ls.dtype)
+ h_L, h_U = v[0][0], v[0][1]
+ h_Ls, h_Us = self.forward(h_L), self.forward(h_U)
+ # If the interval crosses pi/2 (+ 2*k*pi), the max is fixed at 1.0.
+ max_mask = check_crossing(h_L, h_U, self.max_point)
+ # If the interval crosses pi*3/2 (+ 2*k*pi), the min is fixed at -1.0.
+ min_mask = check_crossing(h_L, h_U, self.min_point)
+ ub = torch.max(h_Ls, h_Us)
+ ub = max_mask + (1 - max_mask) * ub
+ lb = torch.min(h_Ls, h_Us)
+ lb = - min_mask + (1 - min_mask) * lb
+ return lb, ub
+
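The same IBP rule, restated as a small standalone function for illustration:

    import math, torch

    def sin_interval(h_L, h_U):
        def crosses(point):
            return (((h_U - point) / (2 * math.pi)).floor()
                    - ((h_L - point) / (2 * math.pi)).floor()) > 0
        sl, su = torch.sin(h_L), torch.sin(h_U)
        ub = torch.where(crosses(math.pi / 2), torch.tensor(1.0), torch.maximum(sl, su))
        lb = torch.where(crosses(3 * math.pi / 2), torch.tensor(-1.0), torch.minimum(sl, su))
        return lb, ub

    print(sin_interval(torch.tensor(0.1), torch.tensor(0.2)))  # endpoint values
    print(sin_interval(torch.tensor(1.0), torch.tensor(2.0)))  # crosses pi/2, ub = 1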
+ def bound_relax_impl(self, lb, ub):
+ dtype = lb.dtype
+ # Case 1: Connect the two points as a line
+ sub = self.func(ub)
+ slb = self.func(lb)
+ mid = (sub + slb) / 2.
+ smid = self.func((ub + lb) / 2)
+ case1_line_slope = (sub - slb) / (ub - lb + 1e-10)
+ case1_line_bias = slb - case1_line_slope * lb
+ gap = smid - mid
+ # Check how many times the derivative (cos) equals the chord slope within [lb, ub].
+ grad_crossings = self.get_intersection(lb, ub, case1_line_slope, theta=0.5 * math.pi)
+ # If it happens exactly once, the chord does not cross the sin curve between the endpoints, so we can use the line connecting the two points as a lower/upper bound.
+ use_line = grad_crossings == 1
+ # Connected line is the upper bound.
+ upper_use_line = torch.logical_and(gap < 0, use_line)
+ # Connected line is the lower bound.
+ lower_use_line = torch.logical_and(gap >= 0, use_line)
+ # For the other bound, use the tangent line.
+ case1_tangent_point = self.get_bounding_slope(lb, ub, case1_line_slope, theta=0.5 * math.pi)
+ case1_tangent_slope = case1_line_slope # Use the same slope so far.
+ stangent = self.func(case1_tangent_point)
+ case1_tangent_bias = stangent - case1_tangent_slope * case1_tangent_point
+ # Choose the lower/upper based on gap.
+ case1_lower_slope = lower_use_line * case1_line_slope + upper_use_line * case1_tangent_slope
+ case1_lower_bias = lower_use_line * case1_line_bias + upper_use_line * case1_tangent_bias
+ case1_upper_slope = upper_use_line * case1_line_slope + lower_use_line * case1_tangent_slope
+ case1_upper_bias = upper_use_line * case1_line_bias + lower_use_line * case1_tangent_bias
+
+ # Case 2: we will try the global lower/upper bounds at lb and ub.
+ # For the points and lb and ub, we can construct both lower and upper bounds.
+ left_lower = self.get_bound_tb(BoundSin.xl_lower_tb, lb) # slope, bias.
+ left_upper = self.get_bound_tb(BoundSin.xl_upper_tb, lb)
+ right_lower = self.get_bound_tb(BoundSin.xu_lower_tb, ub)
+ right_upper = self.get_bound_tb(BoundSin.xu_upper_tb, ub)
+ # Determine which lower bound is tighter.
+ left_lower_error = sub - (left_lower[0] * ub + left_lower[1])
+ right_lower_error = slb - (right_lower[0] * lb + right_lower[1])
+ left_upper_error = (left_upper[0] * ub + left_upper[1]) - sub
+ right_upper_error = (right_upper[0] * lb + right_upper[1]) - slb
+ use_left_lower = (left_lower_error < right_lower_error).to(dtype)
+ use_right_lower = 1. - use_left_lower
+ use_left_upper = (left_upper_error < right_upper_error).to(dtype)
+ use_right_upper = 1. - use_left_upper
+ # Choose the slope and bias in this case.
+ case_2_lower_slope = use_left_lower * left_lower[0] + use_right_lower * right_lower[0]
+ case_2_lower_bias = use_left_lower * left_lower[1] + use_right_lower * right_lower[1]
+ case_2_upper_slope = use_left_upper * left_upper[0] + use_right_upper * right_upper[0]
+ case_2_upper_bias = use_left_upper * left_upper[1] + use_right_upper * right_upper[1]
+
+ # Finally, choose between case 1 and case 2.
+ use_line = use_line.to(dtype)
+ not_use_line = 1. - use_line
+ lower_slope = use_line * case1_lower_slope + not_use_line * case_2_lower_slope
+ lower_bias = use_line * case1_lower_bias + not_use_line * case_2_lower_bias
+ upper_slope = use_line * case1_upper_slope + not_use_line * case_2_upper_slope
+ upper_bias = use_line * case1_upper_bias + not_use_line * case_2_upper_bias
+ # print(gap, lower_slope, lower_bias, upper_slope, upper_bias)
+ return lower_slope, lower_bias, upper_slope, upper_bias
+
+ def bound_relax(self, x):
+ lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(x.lower, x.upper)
+ self.lw = lower_slope
+ self.lb = lower_bias
+ self.uw = upper_slope
+ self.ub = upper_bias
+
+
+class BoundCos(BoundSin):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.max_point = 0.0
+ self.min_point = math.pi
+
+ def forward(self, x):
+ return torch.cos(x)
+
+ def bound_relax(self, x):
+ # Shift the input by 0.5*pi, then shift the linear bounds back.
+ lb = x.lower + 0.5 * math.pi
+ ub = x.upper + 0.5 * math.pi
+ lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(lb, ub)
+ self.lw = lower_slope
+ self.lb = lower_slope * (0.5 * math.pi) + lower_bias
+ self.uw = upper_slope
+ self.ub = upper_slope * (0.5 * math.pi) + upper_bias
+
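The shift identity behind this method, checked numerically with a chord bound of sin on the shifted interval (a minimal sketch; the chord is sound there because sin is concave on [0, pi]):

    import math, torch

    l, u = -0.5, 0.5
    zl, zu = l + 0.5 * math.pi, u + 0.5 * math.pi        # shifted interval for sin
    lw = (math.sin(zu) - math.sin(zl)) / (zu - zl)       # chord lower bound of sin
    lb = math.sin(zl) - lw * zl
    x = torch.linspace(l, u, 101)
    cos_lb = lw * x + (lw * 0.5 * math.pi + lb)          # fold the shift into the bias
    assert (torch.cos(x) >= cos_lb - 1e-6).all()         # cos(x) = sin(x + pi/2)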
+
+class BoundAtan(BoundTanh):
+ def __init__(self, attr, inputs, output_index, options):
+ super(BoundTanh, self).__init__(attr, inputs, output_index, options)
+ self.precompute_relaxation('arctan', torch.arctan, self.darctan)
+ # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions.
+ self.alpha_batch_dim = 3
+
+ def forward(self, x):
+ return torch.arctan(x)
+
+ def darctan(self, x):
+ return (x.square() + 1.).reciprocal()
+
+ def bound_relax(self, x):
+ self.bound_relax_impl(x, torch.arctan, self.darctan)
+
+
+class BoundTan(BoundAtan):
+ """
+ The implementation of BoundTan is based on the S-shaped BoundAtan. We use the bounds from its
+ inverse function and directly convert the bounds of the inverse function to bounds of the original
+ function. This trick allows us to quickly implement bounds on inverse functions.
+ """
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x):
+ return torch.tan(x)
+
+ def _check_bounds(self, lower, upper):
+ # Lower and upper bounds must be within the same [-½π, ½π] region.
+ lower_periods = torch.floor((lower + 0.5 * torch.pi) / torch.pi)
+ upper_periods = torch.floor((upper + 0.5 * torch.pi) / torch.pi)
+ if not torch.allclose(lower_periods, upper_periods):
+ print('Tan preactivation lower bounds:\n', lower)
+ print('Tan preactivation upper bounds:\n', upper)
+ raise ValueError("BoundTan received pre-activation bounds that produce infinity. "
+ "The preactivation bounds are too loose. Try to reduce perturbation region.")
+ # Return the period number for each neuron.
+ # Period is 0 => bounds are within [-½π, ½π],
+ # Period is 1 => bounds are within [-½π + π, ½π + π]
+ # Period is -1 => bounds are within [-½π - π, ½π - π]
+ return lower_periods
+
+ def _init_masks(self, x):
+ # The masks now must consider the periodicity.
+ lower = torch.remainder(x.lower + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi
+ upper = torch.remainder(x.upper + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi
+ self.mask_pos = lower >= 0
+ self.mask_neg = upper <= 0
+ self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg))
+
+ def interval_propagate(self, *v):
+ # We need to check if the input lower and upper bounds are within the same period.
+ # Otherwise the bounds become infinity.
+ concrete_lower, concrete_upper = v[0][0], v[0][1]
+ self._check_bounds(concrete_lower, concrete_upper)
+ return super().interval_propagate(*v)
+
+ def bound_relax(self, x):
+ periods = self._check_bounds(x.lower, x.upper)
+ periods = torch.pi * periods
+ # Create a fake x with inversed lower and upper.
+ inverse_x = lambda: None
+ inverse_x.lower = torch.tan(x.lower)
+ inverse_x.upper = torch.tan(x.upper)
+ super().bound_relax(inverse_x)
+ # Lower slope, lower bias, upper slope and upper bias are saved to
+ # self.lw, self.lb, self.uw, self.ub. We need to reverse them.
+ # E.g., y = self.lw * x + self.lb, now becomes x = 1./self.lw * y - self.lb / self.lw
+ # Additionally, we need to add the missing ½π periods.
+ new_upper_slope = 1. / self.lw
+ new_upper_bias = - self.lb / self.lw - periods / self.lw
+ new_lower_slope = 1. / self.uw
+ new_lower_bias = - self.ub / self.uw - periods / self.uw
+ self.lw = new_lower_slope
+ self.lb = new_lower_bias
+ self.uw = new_upper_slope
+ self.ub = new_upper_bias
+
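A numeric sketch of the inversion trick: a sound lower bound of arctan on [tan(l), tan(u)] with positive slope a becomes a sound upper bound of tan with slope 1/a (period terms omitted; the chord of arctan is a valid lower bound for z > 0 by concavity):

    import torch

    l, u = torch.tensor(0.2), torch.tensor(1.0)             # inside (-pi/2, pi/2)
    zl, zu = torch.tan(l), torch.tan(u)
    a = (torch.arctan(zu) - torch.arctan(zl)) / (zu - zl)   # chord slope of arctan
    b = torch.arctan(zl) - a * zl
    x = torch.linspace(float(l), float(u), 101)
    assert (torch.tan(x) <= x / a - b / a + 1e-5).all()     # inverted: upper bound of tan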
+
+class BoundExp(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.options = options.get('exp')
+ self.max_input = 0
+
+ def forward(self, x):
+ if self.loss_fusion and self.options != 'no-max-input':
+ self.max_input = torch.max(x, dim=-1, keepdim=True)[0].detach()
+ return torch.exp(x - self.max_input)
+ return torch.exp(x)
+
+ def interval_propagate(self, *v):
+ assert len(v) == 1
+ # unary monotonous functions only
+ h_L, h_U = v[0]
+ if self.loss_fusion and self.options != 'no-max-input':
+ self.max_input = torch.max(h_U, dim=-1, keepdim=True)[0]
+ h_L, h_U = h_L - self.max_input, h_U - self.max_input
+ else:
+ self.max_input = 0
+ return torch.exp(h_L), torch.exp(h_U)
+
+ def bound_forward(self, dim_in, x):
+ m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)
+
+ exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)
+
+ kl = exp_m
+ lw = x.lw * kl.unsqueeze(1)
+ lb = kl * (x.lb - m + 1)
+
+ ku = (exp_u - exp_l) / (x.upper - x.lower + epsilon)
+ uw = x.uw * ku.unsqueeze(1)
+ ub = x.ub * ku - ku * x.lower + exp_l
+
+ return LinearBound(lw, lb, uw, ub)
+
+ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None):
+ # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable).
+ if self.loss_fusion and last_lA is None and last_uA is not None and torch.min(
+ last_uA) >= 0 and x.from_input:
+ # Adding an extra bias term to the input. This is equivalent to adding a constant and subtract layer before exp.
+ # Note that we also need to adjust the bias term at the end.
+ if self.options == 'no-detach':
+ self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0]
+ elif self.options != 'no-max-input':
+ self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0].detach()
+ else:
+ self.max_input = 0
+ adjusted_lower = x.lower - self.max_input
+ adjusted_upper = x.upper - self.max_input
+ # relaxation for upper bound only (used in loss fusion)
+ exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper)
+ k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower + epsilon)
+ if k.requires_grad:
+ k = k.clamp(min=1e-6)
+ uA = last_uA * k.unsqueeze(0)
+ ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0)
+
+ if ubias.ndim > 2:
+ ubias = torch.sum(ubias, dim=tuple(range(2, ubias.ndim)))
+ # Also adjust the missing ubias term.
+ if uA.ndim > self.max_input.ndim:
+ A = torch.sum(uA, dim=tuple(range(self.max_input.ndim, uA.ndim)))
+ else:
+ A = uA
+
+ # These should hold true in loss fusion
+ assert self.batch_dim == 0
+ assert A.shape[0] == 1
+
+ batch_size = A.shape[1]
+ ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0)
+ return [(None, uA)], 0, ubias
+ else:
+ return super().bound_backward(last_lA, last_uA, x)
+
+ def bound_relax(self, x):
+ min_val = -1e9
+ l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val)
+ m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)
+ exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)
+ k = exp_m
+ self.add_linear_relaxation(mask=None, type='lower', k=k, x0=m, y0=exp_m)
+ min_val = -1e9 # to avoid (-inf)-(-inf) when both input.lower and input.upper are -inf
+ epsilon = 1e-20
+ close = (u - l < epsilon).int()
+ k = close * exp_u + (1 - close) * (exp_u - exp_l) / (u - l + epsilon)
+ self.add_linear_relaxation(mask=None, type='upper', k=k, x0=l, y0=exp_l)
+
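Both relaxations can be verified standalone: by convexity of exp, the tangent at m is a sound lower bound and the chord is a sound upper bound (a minimal sketch):

    import torch

    l, u = torch.tensor(-1.0), torch.tensor(1.5)
    m = torch.min((l + u) / 2, l + 0.99)
    k_low, b_low = torch.exp(m), torch.exp(m) * (1 - m)     # tangent at m
    k_up = (torch.exp(u) - torch.exp(l)) / (u - l)
    b_up = torch.exp(l) - k_up * l                          # chord through endpoints
    x = torch.linspace(float(l), float(u), 101)
    assert (torch.exp(x) >= k_low * x + b_low - 1e-6).all()
    assert (torch.exp(x) <= k_up * x + b_up + 1e-6).all()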
+
+class BoundLog(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x):
+ # NOTE adhoc implementation for loss fusion
+ if self.loss_fusion:
+ return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1)
+ return torch.log(x.clamp(min=epsilon))
+
+ def bound_relax(self, x):
+ rl, ru = self.forward(x.lower), self.forward(x.upper)
+ ku = (ru - rl) / (x.upper - x.lower + epsilon)
+ self.add_linear_relaxation(mask=None, type='lower', k=ku, x0=x.lower, y0=rl)
+ m = (x.lower + x.upper) / 2
+ k = torch.reciprocal(m)
+ rm = self.forward(m)
+ self.add_linear_relaxation(mask=None, type='upper', k=k, x0=m, y0=rm)
+
+ def interval_propagate(self, *v):
+ # NOTE adhoc implementation for loss fusion
+ if self.loss_fusion:
+ par = self.inputs[0].inputs[0].inputs[0]
+ lower = torch.logsumexp(par.lower, dim=-1)
+ upper = torch.logsumexp(par.upper, dim=-1)
+ return lower, upper
+ return super().interval_propagate(*v)
+
+ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None):
+ A, lbias, ubias = super().bound_backward(last_lA, last_uA, x)
+ # NOTE adhoc implementation for loss fusion
+ if self.loss_fusion:
+ assert A[0][0] is None
+ exp_module = self.inputs[0].inputs[0]
+ ubias = ubias + self.get_bias(A[0][1], exp_module.max_input.squeeze(-1))
+ return A, lbias, ubias
+
+
+class BoundPow(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x, y):
+ return torch.pow(x, y)
+
+ def bound_backward(self, last_lA, last_uA, x, y):
+ assert not self.is_input_perturbed(1)
+ y = y.lower.item()
+ if y == int(y) and y == 2:
+ x_l = x.lower
+ x_u = torch.max(x.upper, x.lower + 1e-8)
+
+ pow_l = self.forward(x_l, y)
+ pow_u = self.forward(x_u, y)
+ k_u = (pow_u - pow_l) / (x_u - x_l).clamp(min=1e-8)
+ b_u = pow_l - k_u * x_l
+
+ k_l = torch.zeros_like(k_u)
+ b_l = torch.zeros_like(b_u)
+ x_m = (x_l + x_u) / 2
+
+ # TODO this only holds for y=2
+ x_m = (x_u < 0) * torch.max(x_m, x_u * 2) + (x_l > 0) * torch.min(x_m, x_l * 2)
+ k_l = y * self.forward(x_m, y - 1)
+ b_l = self.forward(x_m, y) - k_l * x_m
+
+ if last_lA is not None:
+ last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0)
+ lA = last_lA_pos * k_l + last_lA_neg * k_u
+ lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u)
+ else:
+ lA, lb = None, 0
+
+ if last_uA is not None:
+ last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0)
+ uA = last_uA_pos * k_u + last_uA_neg * k_l
+ ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l)
+ else:
+ uA, ub = None, 0
+
+ return [(lA, uA), (None, None)], lb, ub
+ else:
+ raise NotImplementedError(f'Exponent {y} is not supported yet')
+
+ def interval_propagate(self, *v):
+ assert not self.is_input_perturbed(1)
+ exp = v[1][0]
+ assert exp == int(exp)
+ exp = int(exp)
+ pl, pu = torch.pow(v[0][0], exp), torch.pow(v[0][1], exp)
+ if exp % 2 == 1:
+ return pl, pu
+ else:
+ pl, pu = torch.min(pl, pu), torch.max(pl, pu)
+ mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).to(pl.dtype)
+ return pl * mask, pu
+
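The even/odd-exponent rule, restated as a small standalone function for illustration:

    import torch

    def pow_interval(h_L, h_U, exp: int):
        pl, pu = h_L ** exp, h_U ** exp
        if exp % 2 == 1:
            return pl, pu                    # odd powers are monotone
        pl, pu = torch.minimum(pl, pu), torch.maximum(pl, pu)
        mask = 1 - ((h_L < 0) & (h_U > 0)).to(pl.dtype)   # 0 where the interval spans 0
        return pl * mask, pu

    print(pow_interval(torch.tensor([-2.0]), torch.tensor([1.0]), 2))  # (0, 4)
    print(pow_interval(torch.tensor([1.0]), torch.tensor([2.0]), 3))   # (1, 8)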
+
+class BoundReciprocal(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x):
+ return torch.reciprocal(x)
+
+ def bound_relax(self, x):
+ m = (x.lower + x.upper) / 2
+ kl = -1 / m.pow(2)
+ self.add_linear_relaxation(mask=None, type='lower', k=kl, x0=m, y0=1. / m)
+ ku = -1. / (x.lower * x.upper)
+ self.add_linear_relaxation(mask=None, type='upper', k=ku, x0=x.lower, y0=1. / x.lower)
+
+ def interval_propagate(self, *v):
+ h_L, h_U = v[0][0].float(), v[0][1].float()
+ assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal'
+ return torch.reciprocal(h_U), torch.reciprocal(h_L)
+
+
+class BoundSqrt(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x):
+ return torch.sqrt(x)
+
+ def interval_propagate(self, *v):
+ return super().interval_propagate(*v)
+
+ def bound_backward(self, last_lA, last_uA, x):
+ x_l = x.lower
+ x_u = torch.max(x.upper, x.lower + 1e-8)
+ sqrt_l = self.forward(x_l)
+ sqrt_u = self.forward(x_u)
+ k_l = (sqrt_u - sqrt_l) / (x_u - x_l).clamp(min=1e-8)
+ b_l = sqrt_l - k_l * x_l
+
+ x_m = (x_l + x_u) / 2
+ sqrt_m = self.forward(x_m)
+ k_u = -0.5 * torch.pow(x_m, -1.5)
+ b_u = sqrt_m - k_u * x_m
+
+ # TODO make this part a general function
+ if last_lA is not None:
+ last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0)
+ lA = last_lA_pos * k_l + last_lA_neg * k_u
+ lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u)
+ else:
+ lA, lb = None, 0
+ if last_uA is not None:
+ last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0)
+ uA = last_uA_pos * k_u + last_uA_neg * k_l
+ ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l)
+ else:
+ uA, ub = None, 0
+
+ return [(lA, uA), (None, None)], lb, ub
+
+
+class BoundSqr(BoundActivation):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.nonlinear = True
+
+ def forward(self, x):
+ return x**2
+
+ def bound_backward(self, last_lA, last_uA, x):
+ x_L, x_U = x.lower, x.upper
+ upper_k = x_U + x_L
+ upper_b = x_L**2 - upper_k * x_L
+ if last_uA is not None:
+ # Special case if we only want the upper bound with non-negative
+ # coefficients.
+ if last_uA.min() >= 0:
+ uA = last_uA * upper_k
+ ubias = self.get_bias(last_uA, upper_b)
+ else:
+ raise NotImplementedError
+ else:
+ uA, ubias = None, 0
+ if last_lA is not None:
+ if last_lA.max() <= 0:
+ lA = last_lA * upper_k
+ lbias = self.get_bias(last_lA, upper_b)
+ else:
+ raise NotImplementedError
+ else:
+ lA, lbias = None, 0
+ return [(lA, uA)], lbias, ubias
+
+ def interval_propagate(self, *v):
+ h_L, h_U = v[0][0], v[0][1]
+ lower = ((h_U < 0) * (h_U**2) + (h_L > 0) * (h_L**2))
+ upper = torch.max(h_L**2, h_U**2)
+ return lower, upper
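The chord coefficients used above can be checked directly: by convexity of x**2, the line of slope x_L + x_U through (x_L, x_L**2) upper-bounds the square on [x_L, x_U]:

    import torch

    x_L, x_U = torch.tensor(-1.0), torch.tensor(2.0)
    k = x_U + x_L                       # chord slope of the parabola
    b = x_L ** 2 - k * x_L
    x = torch.linspace(float(x_L), float(x_U), 101)
    assert (x ** 2 <= k * x + b + 1e-6).all()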
diff --git a/auto_LiRPA/operators/normalization.py b/auto_LiRPA/operators/normalization.py
index 75b669b..d1050ed 100644
--- a/auto_LiRPA/operators/normalization.py
+++ b/auto_LiRPA/operators/normalization.py
@@ -1,23 +1,25 @@
""" Normalization operators"""
from .base import *
+from .solver_utils import grb
class BoundBatchNormalization(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device, training):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options, training):
+ super().__init__(attr, inputs, output_index, options)
self.eps = attr['epsilon']
self.momentum = round(1 - attr['momentum'], 5) # take care!
- self.mode = options.get("conv_mode", "matrix")
self.options = options.get("bn", {})
# modes:
# - forward: use mean and variance estimated from clean forward pass
# - ibp: use mean and variance estimated from ibp
- self.bn_mode = self.options.get("mode", "forward")
+ self.bn_mode = self.options.get("mode", "forward")
self.use_mean = self.options.get("mean", True)
self.use_var = self.options.get("var", True)
self.use_affine = self.options.get("affine", True)
self.training = training
+ self.patches_start = True
+ self.mode = options.get("conv_mode", "matrix")
if not self.use_mean or not self.use_var:
- logger.info(f'Batch normalization node {self.name}: use_mean {self.use_mean}, use_var {self.use_var}')
+ logger.info(f'Batch normalization node {self.name}: use_mean {self.use_mean}, use_var {self.use_var}')
def _check_unused_mean_or_var(self):
# Check if either mean or var is opted out
@@ -26,8 +28,9 @@ def _check_unused_mean_or_var(self):
if not self.use_var:
self.current_var = torch.ones_like(self.current_var)
- @Bound.save_io_shape
def forward(self, x, w, b, m, v):
+ if len(x.shape) == 2:
+ self.patches_start = False
if self.training:
dim = [0] + list(range(2, x.ndim))
self.current_mean = x.mean(dim)
@@ -35,10 +38,10 @@ def forward(self, x, w, b, m, v):
else:
self.current_mean = m.data
self.current_var = v.data
- self._check_unused_mean_or_var()
+ self._check_unused_mean_or_var()
if not self.use_affine:
w = torch.ones_like(w)
- b = torch.zeros_like(b)
+ b = torch.zeros_like(b)
result = F.batch_norm(x, m, v, w, b, self.training, self.momentum, self.eps)
if not self.use_mean or not self.use_var:
# If mean or variance is disabled, recompute the output from self.current_mean
@@ -61,15 +64,15 @@ def bound_backward(self, last_lA, last_uA, *x):
self._check_unused_mean_or_var()
if not self.use_affine:
weight = torch.ones_like(weight)
- bias = torch.zeros_like(bias)
-
+ bias = torch.zeros_like(bias)
+
tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight
tmp_weight = weight / torch.sqrt(self.current_var + self.eps)
def _bound_oneside(last_A):
if last_A is None:
return None, 0
- if type(last_A) == torch.Tensor:
+ if type(last_A) == Tensor:
next_A = last_A * tmp_weight.view(*((1, 1, -1) + (1,) * (last_A.ndim - 3)))
if last_A.ndim > 3:
sum_bias = (last_A.sum(tuple(range(3, last_A.ndim))) * tmp_bias).sum(2)
@@ -85,11 +88,12 @@ def _bound_oneside(last_A):
# tmp_weight has shape (c,), it will be applied on the (c,) dimension.
patches = patches * tmp_weight.view(*([1] * (patches.ndim - 3)), -1, 1, 1) # Match with sparse or non-sparse patches.
next_A = Patches(patches, last_A.stride, last_A.padding, last_A.shape, identity=0, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape)
-
+
# bias to size (c,), need expansion before unfold.
bias = tmp_bias.view(-1,1,1).expand(self.input_shape[1:]).unsqueeze(0)
# Unfolded bias has shape (1, out_h, out_w, in_c, H, W).
- bias_unfolded = inplace_unfold(bias, kernel_size=last_A.patches.shape[-2:], padding=last_A.padding, stride=last_A.stride)
+ bias_unfolded = inplace_unfold(bias, kernel_size=last_A.patches.shape[-2:], padding=last_A.padding, stride=last_A.stride,
+ inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding)
if last_A.unstable_idx is not None:
# Sparse bias has shape (unstable_size, batch, in_c, H, W).
bias_unfolded = bias_unfolded[:, last_A.unstable_idx[1], last_A.unstable_idx[2]]
@@ -130,6 +134,7 @@ def _bound_oneside(last_A):
return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], lbias, ubias
+
def interval_propagate(self, *v):
assert not self.is_input_perturbed(1) and not self.is_input_perturbed(2), \
'Weight perturbation is not supported for BoundBatchNormalization'
@@ -153,17 +158,75 @@ def interval_propagate(self, *v):
self._check_unused_mean_or_var()
if not self.use_affine:
weight = torch.ones_like(weight)
- bias = torch.zeros_like(bias)
+ bias = torch.zeros_like(bias)
tmp_weight = weight / torch.sqrt(self.current_var + self.eps)
tmp_weight_abs = tmp_weight.abs()
tmp_bias = bias - self.current_mean * tmp_weight
-
shape = (1, -1) + (1,) * (mid.ndim - 2)
- center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape)
- deviation = tmp_weight_abs.view(*shape) * diff
- lower = center - deviation
- upper = center + deviation
+
+ # interval_propagate() of the Linear layer may encounter input with different norms.
+ norm, eps = Interval.get_perturbation(v[0])[:2]
+ if norm == np.inf:
+ center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape)
+ deviation = tmp_weight_abs.view(*shape) * diff
+ elif norm > 0:
+ mid = v[0][0]
+ center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape)
+ if norm == 2:
+ ptb = copy.deepcopy(v[0].ptb)
+ ptb.eps = eps * tmp_weight_abs.max()
+ return Interval(center, center, ptb=ptb)
+ else:
+ # General Lp norm.
+ center = tmp_weight.view(*shape) * mid
+ deviation = tmp_weight_abs.view(*shape) * eps # use a Linf ball to replace Lp norm
+ else:
+ raise NotImplementedError
+
+ lower, upper = center - deviation, center + deviation
return lower, upper
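The folding of batch normalization into tmp_weight and tmp_bias can be verified against F.batch_norm in evaluation mode (a minimal sketch with random statistics):

    import torch

    x = torch.randn(2, 3, 4, 4)
    w, b = torch.rand(3), torch.randn(3)
    mean, var, eps = torch.randn(3), torch.rand(3) + 0.1, 1e-5
    tmp_weight = w / torch.sqrt(var + eps)
    tmp_bias = b - mean * tmp_weight
    shape = (1, -1, 1, 1)                              # broadcast over channels
    y_affine = tmp_weight.view(shape) * x + tmp_bias.view(shape)
    y_bn = torch.nn.functional.batch_norm(x, mean, var, w, b, False, 0.1, eps)
    assert torch.allclose(y_affine, y_bn, atol=1e-5)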
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ # e.g., last layer input gurobi vars (3,32,32)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1,3,32,32)
+ pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape
+ # this layer shape (1,8,16,16)
+ this_layer_shape = self.output_shape
+
+ weight, bias = v[1], v[2]
+
+ self.current_mean = v[3]
+ self.current_var = v[4]
+ self._check_unused_mean_or_var()
+ if not self.use_affine:
+ weight = torch.ones_like(weight)
+ bias = torch.zeros_like(bias)
+
+ tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight
+ tmp_weight = weight / torch.sqrt(self.current_var + self.eps)
+
+ new_layer_gurobi_vars = []
+ neuron_idx = 0
+ for out_chan_idx in range(this_layer_shape[1]):
+ out_chan_vars = []
+ for out_row_idx in range(this_layer_shape[2]):
+ out_row_vars = []
+ for out_col_idx in range(this_layer_shape[3]):
+ # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1))
+ lin_expr = tmp_bias[out_chan_idx].item() + tmp_weight[out_chan_idx].item() * gvars_array[out_chan_idx, out_row_idx, out_col_idx]
+ var = model.addVar(lb=-float('inf'), ub=float('inf'),
+ obj=0, vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')
+ neuron_idx += 1
+
+ out_row_vars.append(var)
+ out_chan_vars.append(out_row_vars)
+ new_layer_gurobi_vars.append(out_chan_vars)
+
+ self.solver_vars = new_layer_gurobi_vars
+ # self.solver_constrs = new_layer_gurobi_constrs
+ model.update()
diff --git a/auto_LiRPA/operators/pooling.py b/auto_LiRPA/operators/pooling.py
new file mode 100644
index 0000000..6163caa
--- /dev/null
+++ b/auto_LiRPA/operators/pooling.py
@@ -0,0 +1,550 @@
+""" Convolution, pooling and padding operators"""
+from .base import *
+from .activation import BoundOptimizableActivation
+import numpy as np
+from .solver_utils import grb
+
+
+class BoundMaxPool(BoundOptimizableActivation):
+ #FIXME clean up needed
+
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])
+ assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])
+
+ self.requires_input_bounds = [0]
+ self.kernel_size = attr['kernel_shape']
+ self.stride = attr['strides']
+ self.padding = [attr['pads'][0], attr['pads'][1]]
+ self.ceil_mode = False
+ self.use_default_ibp = True
+ self.alpha = None
+ self.init = {}
+ self.alpha_batch_dim = 2
+
+ def forward(self, x):
+ output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ return output
+
+ def project_simplex(self, patches):
+ sorted = torch.flatten(patches, -2)
+ sorted, _ = torch.sort(sorted, -1, descending=True)
+ rho_sum = torch.cumsum(sorted, -1)
+ rho_value = 1 - rho_sum
+ rho_value = (sorted + rho_value/torch.tensor(range(1, sorted.size(-1)+1), dtype=torch.float, device=sorted.device)) > 0
+ _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1)
+ rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1)
+ lbd = 1/(rho_index+1)* (1-rho_sum)
+
+ return torch.clamp(patches + lbd.unsqueeze(-1).unsqueeze(-1), min=0)
+
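For reference, the sort/cumsum projection used here can be sketched for a single flattened patch (a hypothetical 1-D helper, not the library's batched version):

    import torch

    def project_simplex_1d(v):
        s, _ = torch.sort(v, descending=True)
        css = torch.cumsum(s, 0)
        idx = torch.arange(1, v.numel() + 1, dtype=v.dtype)
        rho = torch.nonzero(s + (1 - css) / idx > 0).max()   # last active index
        lam = (1 - css[rho]) / (rho + 1)
        return torch.clamp(v + lam, min=0)

    p = project_simplex_1d(torch.tensor([0.5, 0.4, 0.3, -0.2]))
    print(p, p.sum())          # nonnegative entries summing to 1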
+ def init_opt_parameters(self, start_nodes):
+ self.alpha = OrderedDict()
+ ref = self.inputs[0].lower # a reference variable for getting the shape
+ for ns, size_s, unstable_idx in start_nodes:
+ if ns == '_forward':
+ warnings.warn("MaxPool's optimization is not supported for forward mode")
+ continue
+ self.alpha[ns] = torch.empty(
+ [1, size_s, self.input_shape[0], self.input_shape[1],
+ self.output_shape[-2], self.output_shape[-1],
+ self.kernel_size[0], self.kernel_size[1]],
+ dtype=torch.float, device=ref.device, requires_grad=True)
+ self.init[ns] = False
+
+ @staticmethod
+ @torch.jit.script
+ def jit_mutiply(Apos, Aneg, pos, neg):
+ return pos.contiguous() * Apos + neg.contiguous() * Aneg
+
+ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, unstable_idx=None):
+ # self.padding is a tuple of two elements: (height dimension padding, width dimension padding).
+ paddings = tuple((self.padding[0], self.padding[0], self.padding[1], self.padding[1]))
+
+ if self.stride[0] != self.kernel_size[0]:
+ raise ValueError("self.stride ({}) != self.kernel_size ({})".format(self.stride, self.kernel_size))
+ if self.padding[0] != 0:
+ raise ValueError("BoundMaxPool doesn't support padding != 0")
+
+ shape = self.input_shape
+ batch_size = x.lower.shape[0]
+ shape = list(shape[:-2]) + [a + 2*b for a, b in zip(self.input_shape[-2:], self.padding)]
+ shape[0] = batch_size
+ # Lower and upper D matrices. They have size (batch_size, input_c, x, y) and will be multiplied onto the A matrices enlarged via F.interpolate.
+ upper_d = torch.zeros(shape, device=x.device)
+ lower_d = None
+
+ # Size of upper_b and lower_b: (batch_size, output_c, h, w).
+ upper_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)
+ lower_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)
+
+ # Find the maxpool neuron whose input bounds satisfy l_i > max_j u_j for all j != i. In this case, the maxpool neuron is linear, and we can set upper_d = lower_d = 1.
+ # We first find which index has the largest lower bound.
+ max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ # Set the upper bound of the i-th input to -inf so it will not be selected as the max.
+
+ if paddings == (0,0,0,0):
+ delete_upper = torch.scatter(
+ torch.flatten(x.upper, -2), -1,
+ torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
+ else:
+ delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
+ # Find the max upper bound over the remaining ones.
+ max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode)
+
+ # The upper bound slope for maxpool is either 1 for an input satisfying l_i > max_j u_j (the linear case), or 0 everywhere. The upper bound is not optimized.
+ values = torch.zeros_like(max_lower)
+ values[max_lower >= max_upper] = 1.0
+ upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape)
+
+ if self.opt_stage == 'opt':
+ if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1:
+ if unstable_idx.ndim == 1:
+ # Only unstable neurons of the start_node neurons are used.
+ alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
+ elif unstable_idx.ndim == 2:
+ # Each element in the batch selects different neurons.
+ alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
+ else:
+ raise ValueError
+ else:
+ alpha = self.alpha[start_node.name]
+
+ if not self.init[start_node.name]:
+ lower_d = torch.zeros((shape), device=x.device)
+ # [batch, C, H, W]
+ lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+ # shape [batch, C*k*k, L]
+ lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride)
+
+ # [batch, C, k, k, out_H, out_W]
+ alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1])
+
+ # [batch, C, out_H, out_W, k, k]
+ alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach())
+ self.init[start_node.name] = True
+ # In optimization mode, we use the same lower_d once built.
+ if self.padding[0] > 0 or self.padding[1] > 0:
+ lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]]
+ # The lower bound coefficients must be non-negative and are projected onto a unit simplex.
+ alpha.data = self.project_simplex(alpha.data).clone().detach() # TODO: don't do this; never re-assign the .data property. Use copy_ instead.
+ # Permute the trailing dimensions of alpha to [..., k, k, out_H, out_W], which prepares for the fold operation below.
+ alpha = alpha.permute((0,1,2,3,6,7,4,5))
+ alpha_shape = alpha.shape
+ alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1]))
+ lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride)
+ lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:])
+ lower_d = lower_d.squeeze(0)
+ else:
+ lower_d = torch.zeros((shape), device=x.device)
+ # Non-optimized bounds. We simply set \hat{z} >= z_i, where i is the input element with the largest lower bound.
+ lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+ if self.padding[0] > 0 or self.padding[1] > 0:
+ lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]]
+
+ # For the upper bound, we set the bias term to concrete upper bounds for maxpool neurons that are not linear.
+ max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ upper_b[max_upper > max_lower] = max_upper_[max_upper > max_lower]
+
+ def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):
+ if last_A is None:
+ return None, 0
+
+ bias = 0
+
+ if isinstance(last_A, torch.Tensor):
+ pos_A = last_A.clamp(min=0)
+ neg_A = last_A.clamp(max=0)
+
+ if b_pos is not None:
+ # This is matrix mode, and padding is considered in the previous layers
+ bias = bias + self.get_bias(pos_A, b_pos)
+ if b_neg is not None:
+ bias = bias + self.get_bias(neg_A, b_neg)
+
+ # Here we should confirm that the maxpool patches do not overlap.
+ shape = last_A.size()
+ pos_A = F.interpolate(pos_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size)
+ if self.input_shape[-2] != pos_A.shape[-2] or self.input_shape[-1] != pos_A.shape[-1]:
+ # F.pad takes padding sizes starting from the last dimension.
+ pos_A = F.pad(pos_A, (0, self.input_shape[-1] - pos_A.shape[-1], 0, self.input_shape[-2] - pos_A.shape[-2]))
+ pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:])
+
+ neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size)
+ if self.input_shape[-2] != neg_A.shape[-2] or self.input_shape[-1] != neg_A.shape[-1]:
+ neg_A = F.pad(neg_A, (0, self.input_shape[-1] - neg_A.shape[-1], 0, self.input_shape[-2] - neg_A.shape[-2]))
+ neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:])
+
+ next_A = self.jit_mutiply(pos_A, neg_A, d_pos, d_neg)
+ elif isinstance(last_A, Patches):
+ if b_pos is not None:
+ patch_pos = Patches(
+ last_A.patches.clamp(min=0), last_A.stride, last_A.padding,
+ last_A.shape, unstable_idx=last_A.unstable_idx,
+ output_shape=last_A.output_shape)
+ bias = bias + self.get_bias(patch_pos, b_pos)
+ if b_neg is not None:
+ patch_neg = Patches(
+ last_A.patches.clamp(max=0), last_A.stride, last_A.padding,
+ last_A.shape, unstable_idx=last_A.unstable_idx,
+ output_shape=last_A.output_shape)
+ bias = bias + self.get_bias(patch_neg, b_neg)
+
+ shape = last_A.shape
+ pos_A = last_A.patches.clamp(min=0)
+ neg_A = last_A.patches.clamp(max=0)
+
+ def upsample(last_patches, last_A):
+ if last_A.unstable_idx is None:
+ patches = F.interpolate(
+ last_patches.view(shape[0] * shape[1] * shape[2], *shape[3:]),
+ scale_factor=[1,]+self.kernel_size)
+ patches = patches.view(shape[0], shape[1], shape[2], *patches.shape[1:])
+ else:
+ patches = F.interpolate(
+ last_patches, scale_factor=[1,] + self.kernel_size)
+ return Patches(
+ patches, stride=last_A.stride, padding=last_A.padding,
+ shape=patches.shape, unstable_idx=last_A.unstable_idx,
+ output_shape=last_A.output_shape)
+
+ pos_A = upsample(pos_A, last_A)
+ neg_A = upsample(neg_A, last_A)
+
+ # Compute the padding and stride of the pos_A and neg_A patches; d_pos and d_neg are unfolded below to match them.
+ padding, stride, output_padding = compute_patches_stride_padding(
+ self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride, last_A.inserted_zeros, last_A.output_padding)
+
+ pos_A.padding, pos_A.stride, pos_A.output_padding = padding, stride, output_padding
+ neg_A.padding, neg_A.stride, neg_A.output_padding = padding, stride, output_padding
+
+ # unsqueeze for the spec dimension
+ d_pos = maybe_unfold_patches(d_pos.unsqueeze(0), pos_A)
+ d_neg = maybe_unfold_patches(d_neg.unsqueeze(0), neg_A)
+
+
+ next_A_patches = self.jit_mutiply(pos_A.patches, neg_A.patches, d_pos, d_neg)
+
+ if start_node is not None:
+ self.patch_size[start_node.name] = next_A_patches.size()
+
+
+ next_A = Patches(
+ next_A_patches, stride, padding, next_A_patches.shape,
+ unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape,
+ inserted_zeros=last_A.inserted_zeros, output_padding=output_padding)
+
+ return next_A, bias
+
+ if self.padding[0] > 0:
+ upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
+
+ uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, upper_b, lower_b)
+ lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b)
+
+ return [(lA, uA)], lbias, ubias
+
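
A quick numeric check of the "linear" case handled above (a standalone sketch with made-up bounds, not part of the patch): when one input's lower bound dominates every other input's upper bound, the maxpool output always equals that input, so using slope 1 on it is exact.

    import numpy as np

    rng = np.random.default_rng(0)
    l = np.array([0.9, -1.0, -0.5])   # l[0] > max(u[1:]): the linear case
    u = np.array([1.5, 0.2, 0.8])
    for _ in range(1000):
        x = rng.uniform(l, u)         # any point inside the input bounds
        assert x.max() == x[0]        # the maxpool neuron reduces to x[0]
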
+ def bound_forward(self, dim_in, x):
+ lower_d, lower_b, upper_d, upper_b = self.bound_relax(x)
+
+ def _bound_oneside(w_pos, b_pos, w_neg, b_neg, d, b):
+ d_pos, d_neg = d.clamp(min=0), d.clamp(max=0)
+ w_new = d_pos.unsqueeze(1) * w_pos + d_neg.unsqueeze(1) * w_neg
+ b_new = d_pos * b_pos + d_neg * b_neg
+ if isinstance(self.kernel_size, list) and len(self.kernel_size) == 2:
+ tot_kernel_size = prod(self.kernel_size)
+ elif isinstance(self.kernel_size, int):
+ tot_kernel_size = self.kernel_size ** 2
+ else:
+ raise ValueError(f'Unsupported kernel size {self.kernel_size}')
+ w_pooled = (F.avg_pool2d(w_new.view(-1, *w_new.shape[2:]),
+ self.kernel_size, self.stride, self.padding,
+ ceil_mode=self.ceil_mode) * tot_kernel_size)
+ w_pooled = w_pooled.reshape(w_new.shape[0], -1, *w_pooled.shape[1:])
+ b_pooled = F.avg_pool2d(b_new, self.kernel_size, self.stride, self.padding,
+ ceil_mode=self.ceil_mode) * tot_kernel_size + b
+ return w_pooled, b_pooled
+
+ lw, lb = _bound_oneside(x.lw, x.lb, x.uw, x.ub, lower_d, lower_b)
+ uw, ub = _bound_oneside(x.uw, x.ub, x.lw, x.lb, upper_d, upper_b)
+
+ return LinearBound(lw, lb, uw, ub)
+
+ def bound_relax(self, x):
+ # Only used by forward mode
+ paddings = tuple(self.padding + self.padding)
+ self.upper, self.lower = x.upper, x.lower
+
+ # A_shape = last_lA.shape if last_lA is not None else last_uA.shape
+ # batch_size, input_c, x, y
+ upper_d = torch.zeros_like(x.lower)
+ lower_d = torch.zeros_like(x.lower)
+
+ upper_d = F.pad(upper_d, paddings)
+ lower_d = F.pad(lower_d, paddings)
+
+ # batch_size, output_c, x, y
+ upper_b = torch.zeros((list(self.output_shape))).to(x.lower)
+ lower_b = torch.zeros((list(self.output_shape))).to(x.lower)
+
+ # 1. find the index i where li > uj for all j, then set upper_d = lower_d = 1
+ max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
+ max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode)
+
+ values = torch.zeros_like(max_lower)
+ values[max_lower >= max_upper] = 1.0
+ upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape)
+
+ # FIXME: this branch is disabled (if False) due to a shape error.
+ if False and self.opt_stage == 'opt':
+ alpha = self.alpha[self._start]
+
+ if not self.init[self._start]:
+ lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+ lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride)
+
+ alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1])
+ alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach())
+ self.init[self._start] = True
+ if self.padding[0] > 0:
+ lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
+
+ alpha.data = self.project_simplex(alpha.data).clone().detach()
+ alpha = alpha.permute((0,1,2,3,6,7,4,5))
+ alpha_shape = alpha.shape
+ alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1]))
+ lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride)
+ lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:])
+ lower_d = lower_d.squeeze(0)
+ else:
+ lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+ if self.padding[0] > 0:
+ lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
+
+ values[:] = 0.0
+ max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+ values[max_upper > max_lower] = max_upper_[max_upper > max_lower]
+ upper_b = values
+
+ if self.padding[0] > 0:
+ upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]
+
+ return lower_d, lower_b, upper_d, upper_b
+
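
The lower bound uses optimizable coefficients alpha constrained to the probability simplex: any convex combination of the k*k inputs is a sound lower bound on their maximum. project_simplex itself is defined elsewhere in this library; a minimal single-vector sketch of such a Euclidean projection (following Duchi et al., 2008; illustrative only, not the library's batched implementation) is:

    import numpy as np

    def project_simplex_1d(v):
        # Euclidean projection of v onto {a : a >= 0, sum(a) == 1}.
        u = np.sort(v)[::-1]
        css = np.cumsum(u)
        rho = np.nonzero(u * np.arange(1, len(v) + 1) > (css - 1.0))[0][-1]
        theta = (css[rho] - 1.0) / (rho + 1)
        return np.maximum(v - theta, 0.0)

    a = project_simplex_1d(np.array([0.7, 0.6, -0.2]))
    x = np.array([1.0, 2.0, 3.0])
    assert abs(a.sum() - 1.0) < 1e-9 and np.all(a >= 0)
    assert a @ x <= x.max()    # convex combinations lower-bound the max
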
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ # e.g., last layer input gurobi vars (3,32,32)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1,32,27,27)
+ pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape
+ # this layer shape (1,32,6,6)
+ this_layer_shape = self.output_shape
+ assert this_layer_shape[2] == ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0])
+
+ new_layer_gurobi_vars = []
+ neuron_idx = 0
+ pre_ubs = self.forward(self.inputs[0].upper).detach().cpu().numpy()
+
+ for out_chan_idx in range(this_layer_shape[1]):
+ out_chan_vars = []
+ for out_row_idx in range(this_layer_shape[2]):
+ out_row_vars = []
+ for out_col_idx in range(this_layer_shape[3]):
+ a_sum = 0.0
+ v = model.addVar(lb=-float('inf'), ub=float('inf'),
+ obj=0, vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ for ker_row_idx in range(self.kernel_size[0]):
+ in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx + ker_row_idx
+ if (in_row_idx < 0) or (in_row_idx >= pre_layer_shape[2]):
+ # This is padding -> value of 0
+ continue
+ for ker_col_idx in range(self.kernel_size[1]):
+ in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx
+ if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]):
+ # This is padding -> value of 0
+ continue
+ var = gvars_array[out_chan_idx][in_row_idx][in_col_idx]
+ a = model.addVar(vtype=grb.GRB.BINARY)
+ a_sum += a
+ model.addConstr(v >= var)
+ model.addConstr(v <= var + (1 - a) * pre_ubs[0, out_chan_idx, out_row_idx, out_col_idx])
+ model.addConstr(a_sum == 1, name=f'lay{self.name}_{neuron_idx}_eq')
+ neuron_idx += 1
+ out_row_vars.append(v)
+ out_chan_vars.append(out_row_vars)
+ new_layer_gurobi_vars.append(out_chan_vars)
+
+ self.solver_vars = new_layer_gurobi_vars
+ model.update()
+
+
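
A sanity check of the indicator encoding used in build_solver above (made-up values, not part of the patch). The encoding is v >= x_i for all i; v <= x_i + (1 - a_i) * U; sum_i a_i == 1, where U is the concrete upper bound of the maxpool output (pre_ubs). The big-M term (1 - a_i) * U must cover the gap v - x_i, which holds e.g. for non-negative inputs such as post-ReLU activations:

    import numpy as np

    x = np.array([0.3, 0.0, 0.9])   # illustrative post-ReLU input values
    U = 0.9                         # concrete upper bound on the output
    i_star = int(np.argmax(x))
    a = np.eye(len(x))[i_star]      # binary indicator picking the maximum
    v = x.max()
    assert np.all(v >= x)                      # lower constraints hold
    assert np.all(v <= x + (1 - a) * U)        # big-M constraints hold
    assert a.sum() == 1                        # exactly one indicator active
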
+
+class BoundGlobalAveragePool(Bound):
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+
+ def forward(self, x):
+ # Global (adaptive) average pooling with output size (1, 1).
+ return F.adaptive_avg_pool2d(x, (1, 1))
+
+ def bound_backward(self, last_lA, last_uA, x):
+ H, W = self.input_shape[-2], self.input_shape[-1]
+
+ lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None
+ uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None
+
+ return [(lA, uA)], 0, 0
+
+ def interval_propagate(self, *v):
+ h_L, h_U = v[0]
+ h_L = F.adaptive_avg_pool2d(h_L, (1, 1))
+ h_U = F.adaptive_avg_pool2d(h_U, (1, 1))
+ return h_L, h_U
+
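
The backward rule above spreads each output coefficient uniformly over the H*W input locations with weight 1/(H*W). A quick equivalence check (illustrative shapes, spec dimension of size 1; not part of the patch):

    import torch
    import torch.nn.functional as F

    H, W = 4, 4
    x = torch.randn(1, 3, H, W)
    y = F.adaptive_avg_pool2d(x, (1, 1))        # forward: global average
    A = torch.randn(1, 1, 3, 1, 1)              # coefficients on the output
    A_in = A.expand(1, 1, 3, H, W) / (H * W)    # propagated coefficients
    lhs = (A.squeeze(0) * y).sum()              # A applied to the output
    rhs = (A_in.squeeze(0) * x).sum()           # A_in applied to the input
    assert torch.allclose(lhs, rhs, atol=1e-5)
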
+
+class BoundAveragePool(Bound):
+ def __init__(self, attr, inputs, output_index, options):
+ # assumptions: ceil_mode=False, count_include_pad=True
+ super().__init__(attr, inputs, output_index, options)
+
+ assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])
+ assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])
+
+ self.kernel_size = attr['kernel_shape']
+ assert len(self.kernel_size) == 2
+ self.stride = attr['strides']
+ assert len(self.stride) == 2
+ # FIXME (22/07/02): padding is inconsistently handled. Should use 4-tuple.
+
+ if 'pads' not in attr:
+ self.padding = [0, 0]
+ else:
+ self.padding = [attr['pads'][0], attr['pads'][1]]
+ self.ceil_mode = False
+ self.count_include_pad = True
+ self.use_default_ibp = True
+
+ def forward(self, x):
+ return F.avg_pool2d(x, self.kernel_size, self.stride,
+ self.padding, self.ceil_mode, self.count_include_pad)
+
+ def bound_backward(self, last_lA, last_uA, x):
+ def _bound_oneside(last_A):
+ if last_A is None:
+ return None, 0
+ if isinstance(last_A, torch.Tensor):
+ shape = last_A.size()
+ # propagate A to the next layer, with batch concatenated together
+ next_A = F.interpolate(last_A.view(shape[0] * shape[1], *shape[2:]),
+ scale_factor=self.kernel_size) / (prod(self.kernel_size))
+ # F.pad takes padding sizes starting from the last dimension.
+ next_A = F.pad(next_A, (0, self.input_shape[-1] - next_A.shape[-1], 0, self.input_shape[-2] - next_A.shape[-2]))
+ next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])
+ elif isinstance(last_A, Patches):
+ patches = last_A.patches
+ shape = patches.size()
+ # When the number of inserted zeros can cancel out the stride, we use a shortcut that can reduce computation.
+ simplify_patch = (last_A.inserted_zeros + 1 == self.kernel_size[0]) and (self.kernel_size[0] == self.kernel_size[1])
+ padding, stride, output_padding = compute_patches_stride_padding(
+ self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride,
+ inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding, simplify=not simplify_patch)
+ inserted_zeros = last_A.inserted_zeros
+ if last_A.inserted_zeros == 0:
+ # No inserted zeros, can be handled using interpolate.
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.interpolate(patches.view(shape[0] * shape[1], shape[2] * shape[3], *shape[4:]), scale_factor=[1,] + self.kernel_size)
+ # The patch_H and patch_W dimensions have changed.
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.interpolate(patches, scale_factor=[1,] + self.kernel_size)
+ # Divided by the averaging factor.
+ up_sampled_patches = up_sampled_patches / prod(self.kernel_size)
+ elif simplify_patch:
+ padding = tuple(p // s - o for p, s, o in zip(padding, stride, output_padding))
+ output_padding = (0, 0, 0, 0)
+ stride = 1 # Stride and inserted zero canceled out. No need to insert zeros and add output_padding.
+ inserted_zeros = 0
+ value = 1. / prod(self.kernel_size)
+ # In the case where the stride and inserted zeros cancel out, we do not need to insert zeros.
+ weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device)
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=1, groups=self.input_shape[1])
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=1, groups=self.input_shape[1])
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ else:
+ # With inserted zeros, must be handled by treating pooling as general convolution.
+ value = 1. / prod(self.kernel_size)
+ weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device)
+ weight = insert_zeros(weight, last_A.inserted_zeros)
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=self.kernel_size, groups=self.input_shape[1])
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=self.kernel_size, groups=self.input_shape[1])
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ next_A = last_A.create_similar(up_sampled_patches, stride=stride, padding=padding, output_padding=output_padding, inserted_zeros=inserted_zeros)
+ else:
+ raise ValueError(f'last_A has unexpected type {type(last_A)}')
+ return next_A, 0.
+
+ lA, lbias = _bound_oneside(last_lA)
+ uA, ubias = _bound_oneside(last_uA)
+ return [(lA, uA)], lbias, ubias
+
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ # e.g., last layer input gurobi vars (3,32,32)
+ gvars_array = np.array(v[0])
+ # pre_layer_shape (1,32,27,27)
+ pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape
+ # this layer shape (1,32,6,6)
+ this_layer_shape = self.output_shape
+ assert this_layer_shape[2] == ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0])
+
+ value = 1.0/(self.kernel_size[0] * self.kernel_size[1])
+ new_layer_gurobi_vars = []
+ neuron_idx = 0
+ for out_chan_idx in range(this_layer_shape[1]):
+ out_chan_vars = []
+ for out_row_idx in range(this_layer_shape[2]):
+ out_row_vars = []
+ for out_col_idx in range(this_layer_shape[3]):
+ lin_expr = 0.0
+ for ker_row_idx in range(self.kernel_size[0]):
+ in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx + ker_row_idx
+ if (in_row_idx < 0) or (in_row_idx >= pre_layer_shape[2]):
+ # This is padding -> value of 0
+ continue
+ for ker_col_idx in range(self.kernel_size[1]):
+ in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx
+ if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]):
+ # This is padding -> value of 0
+ continue
+ coeff = value
+ lin_expr += coeff * gvars_array[out_chan_idx][in_row_idx][in_col_idx]
+ v = model.addVar(lb=-float('inf'), ub=float('inf'),
+ obj=0, vtype=grb.GRB.CONTINUOUS,
+ name=f'lay{self.name}_{neuron_idx}')
+ model.addConstr(lin_expr == v, name=f'lay{self.name}_{neuron_idx}_eq')
+ neuron_idx += 1
+
+ out_row_vars.append(v)
+ out_chan_vars.append(out_row_vars)
+ new_layer_gurobi_vars.append(out_chan_vars)
+
+ self.solver_vars = new_layer_gurobi_vars
+ model.update()
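
For the non-overlapping case (stride == kernel_size, no inserted zeros), the backward A propagation above is nearest-neighbor upsampling divided by the window size: each output coefficient is shared equally by the k*k inputs it averaged. An equivalence check with illustrative shapes (not part of the patch):

    import torch
    import torch.nn.functional as F

    k = 2
    x = torch.randn(1, 3, 4, 4)
    y = F.avg_pool2d(x, k, k)
    A = torch.randn(1, 1, 3, 2, 2)                  # coefficients on y
    A_in = F.interpolate(A.view(3, 1, 2, 2), scale_factor=k) / (k * k)
    A_in = A_in.view(1, 1, 3, 4, 4)
    assert torch.allclose((A * y).sum(), (A_in * x).sum(), atol=1e-5)
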
diff --git a/auto_LiRPA/operators/reduce.py b/auto_LiRPA/operators/reduce.py
index 246cc98..51da427 100644
--- a/auto_LiRPA/operators/reduce.py
+++ b/auto_LiRPA/operators/reduce.py
@@ -3,8 +3,8 @@
class BoundReduceMax(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axes']
# for torch.max, `dim` must be an int
if isinstance(self.axis, list):
@@ -18,10 +18,8 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
in Softmax of Transformers."""
self.fixed_max_index = options.get('fixed_reducemax_index', False)
- @Bound.save_io_shape
def forward(self, x):
- if self.axis < 0:
- self.axis += len(self.input_shape)
+ self.axis = self.make_axis_non_negative(self.axis)
assert self.axis > 0
res = torch.max(x, dim=self.axis, keepdim=self.keepdim)
self.indices = res.indices
@@ -53,20 +51,19 @@ def _bound_oneside(last_A):
class BoundReduceMean(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axes']
self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
return torch.mean(x, dim=self.axis, keepdim=self.keepdim)
def bound_backward(self, last_lA, last_uA, x):
for i in range(len(self.axis)):
if self.axis[i] < 0:
- self.axis[i] = len(self.input_shape) + self.axis[i]
+ self.axis[i] = self.make_axis_non_negative(self.axis[i])
assert self.axis[i] > 0
def _bound_oneside(last_A):
@@ -89,9 +86,7 @@ def _bound_oneside(last_A):
def bound_forward(self, dim_in, x):
assert (self.keepdim)
assert (len(self.axis) == 1)
- axis = self.axis[0]
- if axis < 0:
- axis = len(self.input_shape) + axis
+ axis = self.make_axis_non_negative(self.axis[0])
assert (axis > 0)
size = self.input_shape[axis]
lw = x.lw.sum(dim=axis + 1, keepdim=True) / size
@@ -107,13 +102,12 @@ def infer_batch_dim(self, batch_size, *x):
return x[0]
class BoundReduceSum(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axes'] if 'axes' in attr else None
self.keepdim = bool(attr['keepdims'])
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
if self.axis is not None:
return torch.sum(x, dim=self.axis, keepdim=self.keepdim)
@@ -143,14 +137,13 @@ def _bound_oneside(last_A):
return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0
def bound_forward(self, dim_in, x):
- assert self.keepdim
assert len(self.axis) == 1
- axis = self.axis[0]
- if axis < 0:
- axis = len(self.input_shape) + axis
- assert (axis > 0)
- lw, lb = x.lw.sum(dim=axis + 1, keepdim=True), x.lb.sum(dim=axis, keepdim=True)
- uw, ub = x.uw.sum(dim=axis + 1, keepdim=True), x.ub.sum(dim=axis, keepdim=True)
+ axis = self.make_axis_non_negative(self.axis[0])
+ assert axis > 0
+ lw = x.lw.sum(dim=axis + 1, keepdim=self.keepdim)
+ lb = x.lb.sum(dim=axis, keepdim=self.keepdim)
+ uw = x.uw.sum(dim=axis + 1, keepdim=self.keepdim)
+ ub = x.ub.sum(dim=axis, keepdim=self.keepdim)
return LinearBound(lw, lb, uw, ub)
def infer_batch_dim(self, batch_size, *x):
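
Since reduce operations are linear, bound_forward above propagates linear bounds by summing (or, for ReduceMean, averaging) coefficients along the reduced axis. A toy soundness check with made-up interval bounds (not part of the patch):

    import torch

    lb = torch.tensor([[0.1, -0.3], [0.2, 0.4]])
    ub = lb + 0.5
    x = lb + 0.25                      # any point with lb <= x <= ub
    assert torch.all(x.sum(dim=1) >= lb.sum(dim=1))
    assert torch.all(x.sum(dim=1) <= ub.sum(dim=1))
    # For ReduceMean, the same holds after dividing by the axis size.
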
diff --git a/auto_LiRPA/operators/rnn.py b/auto_LiRPA/operators/rnn.py
index 4d59f27..7330bf6 100644
--- a/auto_LiRPA/operators/rnn.py
+++ b/auto_LiRPA/operators/rnn.py
@@ -3,12 +3,11 @@
class BoundRNN(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.complex = True
self.output_index = output_index
- @Bound.save_io_shape
def forward(self, x, weight_input, weight_recurrent, bias, sequence_length, initial_h):
assert (torch.sum(torch.abs(initial_h)) == 0)
@@ -17,7 +16,7 @@ def forward(self, x, weight_input, weight_recurrent, bias, sequence_length, init
class BoundRNNImpl(nn.Module):
def __init__(self, input_size, hidden_size,
- weight_input, weight_recurrent, bias, output_index, options, device):
+ weight_input, weight_recurrent, bias, output_index, options):
super().__init__()
self.input_size = input_size
@@ -38,7 +37,7 @@ def __init__(self, input_size, hidden_size,
def forward(self, x):
length = x.shape[0]
outputs = []
- hidden = torch.zeros(x.shape[1], self.hidden_size, device=self.device)
+ hidden = torch.zeros(x.shape[1], self.hidden_size).to(x)
for i in range(length):
hidden = self.cell(x[i, :], hidden)
outputs.append(hidden.unsqueeze(0))
@@ -52,7 +51,7 @@ def forward(self, x):
self.model = BoundRNNImpl(
self.input_size, self.hidden_size,
weight_input, weight_recurrent, bias,
- self.output_index, self.device)
+ self.output_index)
self.input = (x,)
return self.model(self.input)
\ No newline at end of file
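
BoundRNN replaces an nn.RNN by an explicit loop over an RNNCell, so that each step becomes ordinary operators that LiRPA can bound. An equivalence sketch with a zero initial hidden state, matching the assert above (shapes and seed are made up for illustration):

    import torch
    import torch.nn as nn

    rnn = nn.RNN(input_size=4, hidden_size=8)
    cell = nn.RNNCell(4, 8)
    cell.weight_ih.data = rnn.weight_ih_l0.data
    cell.weight_hh.data = rnn.weight_hh_l0.data
    cell.bias_ih.data = rnn.bias_ih_l0.data
    cell.bias_hh.data = rnn.bias_hh_l0.data
    x = torch.randn(5, 2, 4)            # (length, batch, input_size)
    out_ref, _ = rnn(x)
    hidden = torch.zeros(2, 8)
    outs = []
    for t in range(5):
        hidden = cell(x[t], hidden)
        outs.append(hidden.unsqueeze(0))
    assert torch.allclose(out_ref, torch.cat(outs, dim=0), atol=1e-5)
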
diff --git a/auto_LiRPA/operators/shape.py b/auto_LiRPA/operators/shape.py
index 1f34b29..7b9c4b2 100644
--- a/auto_LiRPA/operators/shape.py
+++ b/auto_LiRPA/operators/shape.py
@@ -1,26 +1,60 @@
""" Shape operators """
from .base import *
+from ..patches import Patches, patches_to_matrix
+from .linear import BoundLinear
class BoundReshape(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ # It can be set to `view`, so that `view` instead of `reshape` will be used.
+ self.option = options.get('reshape', 'reshape')
- @Bound.save_io_shape
def forward(self, x, shape):
shape = list(shape)
for i in range(len(shape)):
if shape[i] == -1:
shape[i] = prod(x.shape) // int(prod(shape[:i]) * prod(shape[(i + 1):]))
self.shape = shape
- return x.reshape(shape)
+ if self.option == 'view':
+ return x.contiguous().view(shape)
+ else:
+ return x.reshape(shape)
def bound_backward(self, last_lA, last_uA, x, shape):
def _bound_oneside(A):
if A is None:
return None
- # A shape is (spec, batch, *node_shape)
- return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:])
-
+ if type(A) == Patches:
+ if type(self.inputs[0]) == BoundLinear:
+ # Save the shape and it will be converted to matrix in Linear layer.
+ return A.create_similar(input_shape=self.output_shape)
+ if A.unstable_idx is None:
+ patches = A.patches
+ # non-sparse: [batch, out_dim, out_c, out_H, out_W, out_dim, in_c, H, W]
+ # [batch, out_dim*out_c, out_H, out_W, out_dim*in_c, H, W]
+ patches = patches.reshape(
+ patches.shape[0],
+ patches.shape[1]*patches.shape[2], patches.shape[3], patches.shape[4],
+ patches.shape[5]*patches.shape[6], patches.shape[7], patches.shape[8])
+ # expected next_A shape [batch, spec, in_c, in_H , in_W].
+ next_A = patches_to_matrix(
+ patches, [
+ self.input_shape[0]*self.input_shape[1],
+ patches.shape[-3],
+ int(math.sqrt(self.input_shape[-1]//A.patches.shape[-3])),
+ int(math.sqrt(self.input_shape[-1]//A.patches.shape[-3]))],
+ A.stride, A.padding)
+ else:
+ # sparse: [spec, batch, in_c, patch_H, patch_W] (specs depends on the number of unstable neurons).
+ patches = A.patches
+ # expected next_A shape [batch, spec, input_c, in_H, in_W].
+ next_A = patches_to_matrix(patches, [self.input_shape[0]*self.input_shape[1], patches.shape[-3], int(math.sqrt(self.input_shape[-1]//patches.shape[-3])), int(math.sqrt(self.input_shape[-1]//patches.shape[-3]))], A.stride, A.padding, output_shape=A.output_shape, unstable_idx=A.unstable_idx)
+ # Reshape it to [batch, spec, *input_shape] (input_shape is the shape before Reshape operation).
+ next_A = next_A.reshape(A.shape[1], -1, *self.input_shape[1:])
+ return next_A.transpose(0,1)
+ else:
+ return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:])
+ # FIXME: check whether reshape or view should be used here.
return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0
def bound_forward(self, dim_in, x, shape):
@@ -32,15 +66,9 @@ def bound_forward(self, dim_in, x, shape):
return LinearBound(lw, lb, uw, ub)
def interval_propagate(self, *v):
- if Interval.use_relative_bounds(*v):
- return Interval(
- None, None,
- v[0].nominal.reshape(self.shape),
- v[0].lower_offset.reshape(self.shape),
- v[0].upper_offset.reshape(self.shape),
- ptb=v[0].ptb
- )
- return Interval.make_interval(v[0][0].reshape(*v[1][0]), v[0][1].reshape(*v[1][0]), v[0])
+ return Interval.make_interval(
+ self.forward(v[0][0], v[1][0]),
+ self.forward(v[0][1], v[1][0]), v[0])
def infer_batch_dim(self, batch_size, *x):
if x[0] == -1:
@@ -51,29 +79,48 @@ def infer_batch_dim(self, batch_size, *x):
self.input_shape, self.shape, x[0]
))
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ if isinstance(v[0], Tensor):
+ self.solver_vars = self.forward(*v)
+ return
+ gvar_array = np.array(v[0])
+ gvar_array = gvar_array.reshape(v[1].detach().cpu().numpy())[0]
+ self.solver_vars = gvar_array.tolist()
+
class BoundUnsqueeze(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axes = attr['axes']
assert (len(self.axes) == 1)
self.axes = self.axes[0]
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
- if self.axes < 0:
- self.axes = len(self.input_shape) + self.axes + 1
- return x.unsqueeze(self.axes)
+ return x.unsqueeze(self.axes)
def bound_backward(self, last_lA, last_uA, x):
+ self.axes = self.make_axis_non_negative(self.axes, 'output')
if self.axes == 0:
- return last_lA, 0, last_uA, 0
+ # TODO: unsqueeze on batch dimension can be problematic.
+ return [(last_lA, last_uA)], 0, 0
else:
- return [(last_lA.squeeze(self.axes + 1) if last_lA is not None else None,
- last_uA.squeeze(self.axes + 1) if last_uA is not None else None)], 0, 0
+ if type(last_lA) == Patches:
+ lA = Patches(last_lA.patches.squeeze(self.axes - 5), last_lA.stride, last_lA.padding, last_lA.shape, last_lA.identity, last_lA.unstable_idx, last_lA.output_shape)
+ elif last_lA is not None:
+ lA = last_lA.squeeze(self.axes+1)
+ else:
+ lA = None
+ if type(last_uA) == Patches:
+ uA = Patches(last_uA.patches.squeeze(self.axes - 5), last_uA.stride, last_uA.padding, last_uA.shape, last_uA.identity, last_uA.unstable_idx, last_uA.output_shape)
+ elif last_uA is not None:
+ uA = last_uA.squeeze(self.axes+1)
+ else:
+ uA = None
+ return [(lA, uA)], 0, 0
def bound_forward(self, dim_in, x):
+ self.axes = self.make_axis_non_negative(self.axes, 'output')
if len(self.input_shape) == 0:
lw, lb = x.lw.unsqueeze(1), x.lb.unsqueeze(0)
uw, ub = x.uw.unsqueeze(1), x.ub.unsqueeze(0)
@@ -88,19 +135,20 @@ def infer_batch_dim(self, batch_size, *x):
elif self.axes > x[0]:
return x[0]
raise NotImplementedError
-
+
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(v[0])
class BoundSqueeze(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axes = attr['axes']
assert (len(self.axes) == 1)
self.axes = self.axes[0]
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
- return x.squeeze(self.axes)
+ return x.squeeze(self.axes)
def bound_backward(self, last_lA, last_uA, x):
assert (self.axes != 0)
@@ -119,12 +167,11 @@ def infer_batch_dim(self, batch_size, *x):
class BoundFlatten(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.use_default_ibp = True
self.axis = attr['axis']
- @Bound.save_io_shape
def forward(self, x):
return torch.flatten(x, self.axis)
@@ -133,22 +180,42 @@ def _bound_oneside(A):
if A is None:
return None
return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:])
-
return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0
+ def bound_dynamic_forward(self, x, max_dim=None, offset=0):
+ w = torch.flatten(x.lw, self.axis + 1)
+ b = torch.flatten(x.lb, self.axis)
+ x_L = torch.flatten(x.x_L, self.axis)
+ x_U = torch.flatten(x.x_U, self.axis)
+ return LinearBound(w, b, w, b, x_L=x_L, x_U=x_U, tot_dim=x.tot_dim)
+
+ def bound_forward(self, dim_in, x):
+ self.axis = self.make_axis_non_negative(self.axis)
+ assert self.axis > 0
+ return LinearBound(
+ torch.flatten(x.lw, self.axis + 1),
+ torch.flatten(x.lb, self.axis),
+ torch.flatten(x.uw, self.axis + 1),
+ torch.flatten(x.ub, self.axis),
+ )
+
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ # e.g., v[0] input shape (16, 8, 8) => output shape (1024,)
+ self.solver_vars = np.array(v[0]).reshape(-1).tolist()
+ model.update()
+
+
class BoundConcat(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.IBP_rets = None
- @Bound.save_io_shape
def forward(self, *x): # x is a list of tensors
- x = [(item if isinstance(item, torch.Tensor) else torch.tensor(item)) for item in x]
+ x = [(item if isinstance(item, Tensor) else torch.tensor(item)) for item in x]
self.input_size = [item.shape[self.axis] for item in x]
- if self.axis < 0:
- self.axis = x[0].ndim + self.axis
+ self.axis = self.make_axis_non_negative(self.axis)
return torch.cat(x, dim=int(self.axis))
def interval_propagate(self, *v):
@@ -169,15 +236,6 @@ def interval_propagate(self, *v):
all_inf = all(map(lambda x: x is None or x == np.inf, norms))
all_2 = all(map(lambda x: x is None or x == 2, norms))
- if Interval.use_relative_bounds(*v):
- assert all_inf # Only LINF supported for now
- return Interval(
- None, None,
- self.forward(*[_v.nominal for _v in v]),
- self.forward(*[_v.lower_offset for _v in v]),
- self.forward(*[_v.upper_offset for _v in v]),
- )
-
h_L = [_v[0] for _v in v]
h_U = [_v[1] for _v in v]
if all_inf:
@@ -195,14 +253,22 @@ def interval_propagate(self, *v):
raise RuntimeError("BoundConcat does not support inputs with norm {}".format(norms))
def bound_backward(self, last_lA, last_uA, *x):
- if self.axis < 0:
- self.axis = len(self.output_shape) + self.axis
- assert (self.axis > 0)
+ self.axis = self.make_axis_non_negative(self.axis, 'output')
+ assert self.axis > 0
def _bound_oneside(last_A):
if last_A is None:
return None
- return torch.split(last_A, self.input_size, dim=self.axis + 1)
+ if isinstance(last_A, torch.Tensor):
+ return torch.split(last_A, self.input_size, dim=self.axis + 1)
+ elif isinstance(last_A, Patches):
+ assert len(self.input_shape) == 4 and self.axis == 1, "Splitting is only supported along the channel dimension; other axes are unimplemented."
+ # Patches shape can be [out_c, batch, out_h, out_w, in_c, patch_h, patch_w]
+ # Or [spec, batch, in_c, patch_h, patch_w] (sparse)
+ new_patches = torch.split(last_A.patches, self.input_size, dim=-3) # Splitting along the in_c dimension is easy.
+ return [last_A.create_similar(p) for p in new_patches]
+ else:
+ raise RuntimeError(f'Unsupported type for last_A: {type(last_A)}')
uA = _bound_oneside(last_uA)
lA = _bound_oneside(last_lA)
@@ -213,8 +279,7 @@ def _bound_oneside(last_A):
return [(lA[i], uA[i]) for i in range(len(lA))], 0, 0
def bound_forward(self, dim_in, *x):
- if self.axis < 0:
- self.axis = x[0].lb.ndim + self.axis
+ self.axis = self.make_axis_non_negative(self.axis)
assert (self.axis == 0 and not self.from_input or self.from_input)
lw = torch.cat([item.lw for item in x], dim=self.axis + 1)
lb = torch.cat([item.lb for item in x], dim=self.axis)
@@ -222,19 +287,25 @@ def bound_forward(self, dim_in, *x):
ub = torch.cat([item.ub for item in x], dim=self.axis)
return LinearBound(lw, lb, uw, ub)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(*v)
+
def infer_batch_dim(self, batch_size, *x):
assert np.min(x) == np.max(x)
assert x[0] != self.axis
return x[0]
+
+BoundConcatFromSequence = BoundConcat
+
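
The bound_backward rule for BoundConcat above splits last_A along the concat axis; each piece pairs with the corresponding input. A toy equivalence check (shapes made up, not part of the patch):

    import torch

    x1 = torch.randn(1, 2, 3)
    x2 = torch.randn(1, 5, 3)
    y = torch.cat([x1, x2], dim=1)
    A = torch.randn(1, 1, 7, 3)             # (spec, batch, ...) on y
    A1, A2 = torch.split(A, [2, 5], dim=2)  # axis + 1 in A coordinates
    lhs = (A.squeeze(0) * y).sum()
    rhs = (A1.squeeze(0) * x1).sum() + (A2.squeeze(0) * x2).sum()
    assert torch.allclose(lhs, rhs, atol=1e-5)
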
class BoundShape(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
- self.use_default_ibp = True
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
+ self.use_default_ibp = True
@staticmethod
def shape(x):
- return x.shape if isinstance(x, torch.Tensor) else torch.tensor(x).shape
+ return x.shape if isinstance(x, Tensor) else torch.tensor(x).shape
def forward(self, x):
self.from_input = False
@@ -247,59 +318,88 @@ def infer_batch_dim(self, batch_size, *x):
return -1
def interval_propagate(self, *v):
- if Interval.use_relative_bounds(*v):
- shape = self.forward(v[0].nominal)
- if not isinstance(shape, torch.Tensor):
- shape = torch.tensor(shape, device=self.device)
- return Interval(
- None, None,
- shape, torch.zeros_like(shape), torch.zeros_like(shape)
- )
-
return super().interval_propagate(*v)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ if not isinstance(v[0], Tensor):
+ # e.g., v[0] input shape (8, 7, 7) => output its shape (1, 8, 7, 7)
+ gvars_array = np.array(v[0])
+ self.solver_vars = torch.tensor(np.expand_dims(gvars_array, axis=0).shape).long()
+ else:
+ self.solver_vars = torch.tensor(self.forward(v[0])).long()
+
class BoundGather(Bound):
- def __init__(self, input_name, name, ori_name, attr, x, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, x, output_index, options, device)
+ def __init__(self, attr, x, output_index, options):
+ super().__init__(attr, x, output_index, options)
self.axis = attr['axis'] if 'axis' in attr else 0
- self.nonlinear = False # input shape required
- @Bound.save_io_shape
def forward(self, x, indices):
self.indices = indices
+ if self.axis == -1:
+ self.axis = len(x.shape) - 1
x = x.to(self.indices.device) # BoundShape.shape() will return value on cpu only
if indices.ndim == 0:
- # `index_select` requires `indices` to be a 1-D tensor
return torch.index_select(x, dim=self.axis, index=indices).squeeze(self.axis)
elif self.axis == 0:
return torch.index_select(x, dim=self.axis, index=indices.reshape(-1)) \
.reshape(*indices.shape, x.shape[-1])
- raise ValueError(
- 'Unsupported shapes in Gather: data {}, indices {}, axis {}'.format(x.shape, indices.shape, self.axis))
+ elif self.indices.ndim == 1:
+ # `index_select` requires `indices` to be a 1-D tensor
+ return torch.index_select(x, dim=self.axis, index=indices)
+
+ raise ValueError('Unsupported shapes in Gather: data {}, indices {}, axis {}'.format(x.shape, indices.shape, self.axis))
def bound_backward(self, last_lA, last_uA, x, indices):
assert self.from_input
- assert self.indices.ndim == 0 # TODO
- def _bound_oneside(A):
- if A is None:
- return None
- assert (self.indices.ndim == 0)
-
- A = A.unsqueeze(self.axis + 1)
- idx = int(self.indices)
+ def _expand_A_with_zeros(A, axis, idx, max_axis_size):
+ # Recreate A from three parts: before the gathered element, the gathered element itself, and after the gathered element.
tensors = []
if idx > 0:
shape_pre = list(A.shape)
- shape_pre[self.axis + 1] *= idx
- tensors.append(torch.zeros(shape_pre, device=self.device))
+ shape_pre[axis] *= idx
+ # Create the same shape as A, except for the dimension to be gathered.
+ tensors.append(torch.zeros(shape_pre, device=A.device))
+ # The gathered element itself, in the middle.
tensors.append(A)
- if self.input_shape[self.axis] - idx - 1 > 0:
+ if max_axis_size - idx - 1 > 0:
shape_next = list(A.shape)
- shape_next[self.axis + 1] *= self.input_shape[self.axis] - idx - 1
- tensors.append(torch.zeros(shape_next, device=self.device))
- return torch.cat(tensors, dim=self.axis + 1)
+ shape_next[axis] *= max_axis_size - idx - 1
+ # Create the rest part of A.
+ tensors.append(torch.zeros(shape_next, device=A.device))
+ # Concatenate all three parts together.
+ return torch.cat(tensors, dim=axis)
+
+ def _bound_oneside(A):
+ if A is None:
+ return None
+
+ if isinstance(A, torch.Tensor):
+ if self.indices.ndim == 0:
+ A = A.unsqueeze(self.axis + 1)
+ idx = int(self.indices)
+ return _expand_A_with_zeros(A, self.axis + 1, idx, self.input_shape[self.axis])
+ else:
+ shape = list(A.shape)
+ final_A = torch.zeros(*shape[:self.axis + 1], self.input_shape[self.axis], *shape[self.axis + 2:], device=A.device)
+ idx = self.indices.view([*[1]*(self.axis+1), -1, *[1]*len(shape[self.axis + 2:])])
+ idx = idx.repeat([*A.shape[:self.axis+1], 1, *A.shape[self.axis+2:]])
+ final_A.scatter_(dim=self.axis+1, index=idx, src=A)
+ return final_A
+ elif isinstance(A, Patches):
+ if self.indices.ndim == 0:
+ idx = int(self.indices)
+ assert len(self.input_shape) == 4 and self.axis == 1, "Gather is only supported on the channel dimension for Patches mode."
+ # For gather in the channel dimension, we only need to deal with the in_c dimension (-3) in patches.
+ patches = A.patches
+ # -3 is the in_c dimension.
+ new_patches = _expand_A_with_zeros(patches, axis=-3, idx=idx, max_axis_size=self.input_shape[self.axis])
+ return A.create_similar(new_patches)
+ else:
+ raise NotImplementedError
+ else:
+ raise ValueError(f'Unknown last_A type {type(A)}')
return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0
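
Gather's backward rule above scatters A back to the gathered input positions and fills the rest with zeros, so that <A, gather(x)> == <scatter(A), x>. A one-dimensional check (values made up, not part of the patch):

    import torch

    x = torch.randn(6)
    idx = torch.tensor([4, 1])
    A = torch.randn(2)                  # coefficients on the gathered output
    full_A = torch.zeros(6).scatter_(0, idx, A)
    assert torch.allclose((A * x[idx]).sum(), (full_A * x).sum(), atol=1e-6)
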
@@ -321,15 +421,6 @@ def bound_forward(self, dim_in, x, indices):
def interval_propagate(self, *v):
assert not self.is_input_perturbed(1)
-
- if Interval.use_relative_bounds(*v):
- return Interval(
- None, None,
- self.forward(v[0].nominal, v[1].nominal),
- self.forward(v[0].lower_offset, v[1].nominal),
- self.forward(v[0].upper_offset, v[1].nominal)
- )
-
return self.forward(v[0][0], v[1][0]), self.forward(v[0][1], v[1][0])
def infer_batch_dim(self, batch_size, *x):
@@ -339,27 +430,22 @@ def infer_batch_dim(self, batch_size, *x):
else:
return x[1]
- def infer_batch_dim(self, batch_size, *x):
- if x[0] != -1:
- assert self.axis != x[0]
- return x[0]
- else:
- return x[1]
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(v[0], v[1])
class BoundGatherElements(Bound):
- def __init__(self, input_name, name, ori_name, attr, input, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, input, output_index, options, device)
+ def __init__(self, attr, input, output_index, options):
+ super().__init__(attr, input, output_index, options)
self.axis = attr['axis']
- @Bound.save_io_shape
def forward(self, x, index):
self.index = index
return torch.gather(x, dim=self.axis, index=index)
def bound_backward(self, last_lA, last_uA, x, index):
assert self.from_input
-
+
dim = self._get_dim()
def _bound_oneside(last_A):
@@ -377,15 +463,6 @@ def _bound_oneside(last_A):
def interval_propagate(self, *v):
assert not self.is_input_perturbed(1)
-
- if Interval.use_relative_bounds(*v):
- return Interval(
- None, None,
- self.forward(v[0].nominal, v[1].nominal),
- self.forward(v[0].lower_offset, v[1].nominal),
- self.forward(v[0].upper_offset, v[1].nominal)
- )
-
return self.forward(v[0][0], v[1][0]), \
self.forward(v[0][1], v[1][1])
@@ -401,7 +478,7 @@ def bound_forward(self, dim_in, x, index):
def infer_batch_dim(self, batch_size, *x):
assert self.axis != x[0]
return x[0]
-
+
def _get_dim(self):
dim = self.axis
if dim < 0:
@@ -409,16 +486,15 @@ def _get_dim(self):
return dim
class BoundTranspose(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.perm = attr['perm']
self.perm_inv_inc_one = [-1] * (len(self.perm) + 1)
self.perm_inv_inc_one[0] = 0
for i in range(len(self.perm)):
self.perm_inv_inc_one[self.perm[i] + 1] = i + 1
- self.use_default_ibp = True
+ self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
return x.permute(*self.perm)
@@ -441,6 +517,9 @@ def bound_forward(self, dim_in, x):
return LinearBound(lw, lb, uw, ub)
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(*v)
+
def infer_batch_dim(self, batch_size, *x):
if x[0] == -1:
return -1
@@ -449,21 +528,14 @@ def infer_batch_dim(self, batch_size, *x):
class BoundSlice(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.start = attr["starts"][0] if "starts" in attr else None
self.end = attr["ends"][0] if "ends" in attr else None
self.axes = attr["axes"][0] if "axes" in attr else None
self.use_default_ibp = False
- # Older Pytorch version only passes steps as input.
- @Bound.save_io_shape
- def forward(self, x, start=None, end=None, axes=None, steps=1):
- start = self.start if start is None else start
- end = self.end if end is None else end
- axes = self.axes if axes is None else axes
- assert (steps == 1 or steps == -1) and axes == int(axes) and start == int(start) and end == int(end)
- shape = x.shape if isinstance(x, torch.Tensor) else [len(x)]
+ def _fixup_params(self, shape, start, end, axes, steps):
if start < 0:
start += shape[axes]
if end < 0:
@@ -471,19 +543,32 @@ def forward(self, x, start=None, end=None, axes=None, steps=1):
end = 0 # only possible when step == -1
else:
end += shape[axes]
- if steps == -1:
+ if steps == -1:
start, end = end, start + 1 # TODO: more test more negative step size.
end = min(end, shape[axes])
+ return start, end
+
+ # Older Pytorch version only passes steps as input.
+ def forward(self, x, start=None, end=None, axes=None, steps=1):
+ start = self.start if start is None else start
+ end = self.end if end is None else end
+ axes = self.axes if axes is None else axes
+ assert (steps == 1 or steps == -1) and axes == int(axes) and start == int(start) and end == int(end)
+ shape = x.shape if isinstance(x, Tensor) else [len(x)]
+ start, end = self._fixup_params(shape, start, end, axes, steps)
final = torch.narrow(x, dim=int(axes), start=int(start), length=int(end - start))
if steps == -1:
final = torch.flip(final, dims=tuple(axes))
return final
-
- def interval_propagate(self, *v):
+
+ def interval_propagate(self, *v):
lb = tuple(map(lambda x:x[0],v))
ub = tuple(map(lambda x:x[1],v))
return Interval.make_interval(self.forward(*lb), self.forward(*ub))
+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(*v)
+
def infer_batch_dim(self, batch_size, *x):
if x[0] == -1:
return -1
@@ -491,12 +576,71 @@ def infer_batch_dim(self, batch_size, *x):
assert self.axes != x[0]
return x[0]
+ def bound_backward(self, last_lA, last_uA, *x):
+ def _bound_oneside(A, start, end, axes, steps):
+ if A is None:
+ return None
+ if isinstance(A, torch.Tensor):
+ # Reuse the batch and spec dimension of A, and replace other shapes with input.
+ A_shape = A.shape[:2] + self.input_shape[1:]
+ new_A = torch.zeros(size=A_shape, device=A.device, requires_grad=A.requires_grad)
+ # Fill part of the new_A based on start, end, axes and steps.
+ # Skip the spec dimension at the front (axes + 1).
+ dim = axes if axes < 0 else axes + 1
+ indices = torch.arange(start, end, device=A.device)
+ new_A = torch.index_copy(new_A, dim=dim, index=indices, source=A)
+ elif isinstance(A, Patches):
+ assert A.unstable_idx is None
+ assert len(self.input_shape) == 4 and axes == 1, "Slice is only supported on channel dimension."
+ patches = A.patches
+ # patches shape is [out_c, batch, out_h, out_w, in_c, patch_h, patch_w].
+ new_patches_shape = patches.shape[:4] + (self.input_shape[1], ) + patches.shape[-2:]
+ new_patches = torch.zeros(size=new_patches_shape, device=patches.device, requires_grad=patches.requires_grad)
+ indices = torch.arange(start, end, device=patches.device)
+ new_patches = torch.index_copy(new_patches, dim=-3, index=indices, source=patches)
+ # Only the in_c dimension is changed.
+ new_A = A.create_similar(new_patches)
+ else:
+ raise ValueError(f'Unsupported A type {type(A)}')
+ return new_A
+
+ start, end, axes = x[1].value.item(), x[2].value.item(), x[3].value.item()
+ steps = x[4].value.item() if len(x) == 5 else 1 # If step is not specified, it is 1.
+ # Other step sizes are untested; do not enable them for now.
+ assert steps == 1 and axes == int(axes) and start == int(start) and end == int(end)
+ start, end = self._fixup_params(self.input_shape, start, end, axes, steps)
+ # Find the original shape of A.
+ lA = _bound_oneside(last_lA, start, end, axes, steps)
+ uA = _bound_oneside(last_uA, start, end, axes, steps)
+ return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], 0, 0
+
+ def bound_forward(self, dim_in, *inputs):
+ assert len(inputs) == 5 or len(inputs) == 4
+ start = inputs[1].lb.item()
+ end = inputs[2].lb.item()
+ axis = self.make_axis_non_negative(inputs[3].lb.item())
+ assert axis > 0, "Slicing along the batch dimension is not supported yet"
+ steps = inputs[4].lb.item() if len(inputs) == 5 else 1 # If step is not specified, it is 1.
+ assert steps in [1, -1]
+ x = inputs[0]
+ shape = x.lb.shape
+ start, end = self._fixup_params(shape, start, end, axis, steps)
+ lw = torch.narrow(x.lw, dim=axis+1, start=start, length=end - start)
+ uw = torch.narrow(x.uw, dim=axis+1, start=start, length=end - start)
+ lb = torch.narrow(x.lb, dim=axis, start=start, length=end - start)
+ ub = torch.narrow(x.ub, dim=axis, start=start, length=end - start)
+ if steps == -1:
+ # dims must be a tuple of ints; tuple(axis) would fail on an int.
+ lw = torch.flip(lw, dims=(axis + 1,))
+ uw = torch.flip(uw, dims=(axis + 1,))
+ lb = torch.flip(lb, dims=(axis,))
+ ub = torch.flip(ub, dims=(axis,))
+ return LinearBound(lw, lb, uw, ub)
+
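
bound_forward above propagates bounds with torch.narrow on the sliced axis (offset by one in the weight tensors for the extra leading dimension). A quick check that narrow matches the Python slicing semantics of ONNX Slice with step 1 (not part of the patch):

    import torch

    x = torch.arange(10).view(2, 5)
    assert torch.equal(torch.narrow(x, 1, 1, 3), x[:, 1:4])
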
class BoundExpand(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
- @Bound.save_io_shape
def forward(self, x, y):
y = y.clone()
assert y.ndim == 1
@@ -518,18 +662,19 @@ def infer_batch_dim(self, batch_size, *x):
class BoundSplit(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.split = attr['split']
self.use_default_ibp = True
- @Bound.save_io_shape
def forward(self, x):
+ if self.axis == -1:
+ self.axis = len(x.shape) - 1
return torch.split(x, self.split, dim=self.axis)[self.output_index]
def bound_backward(self, last_lA, last_uA, x):
- assert (self.axis > 0)
+ assert self.axis > 0
pre = sum(self.split[:self.output_index])
suc = sum(self.split[(self.output_index + 1):])
diff --git a/auto_LiRPA/operators/softmax.py b/auto_LiRPA/operators/softmax.py
index 21ad6b0..c8fdd1c 100644
--- a/auto_LiRPA/operators/softmax.py
+++ b/auto_LiRPA/operators/softmax.py
@@ -7,7 +7,6 @@ def __init__(self, axis):
self.axis = axis
assert self.axis == int(self.axis)
- @Bound.save_io_shape
def forward(self, x):
max_x = torch.max(x, dim=self.axis).values
x = torch.exp(x - max_x.unsqueeze(self.axis))
@@ -16,8 +15,8 @@ def forward(self, x):
# The `option != 'complex'` case is not used in the auto_LiRPA main paper.
class BoundSoftmax(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.option = options.get('softmax', 'complex')
if self.option == 'complex':
@@ -25,7 +24,6 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
else:
self.max_input = 30
- @Bound.save_io_shape
def forward(self, x):
assert self.axis == int(self.axis)
if self.option == 'complex':
diff --git a/auto_LiRPA/operators/solver_utils.py b/auto_LiRPA/operators/solver_utils.py
new file mode 100644
index 0000000..79a02b2
--- /dev/null
+++ b/auto_LiRPA/operators/solver_utils.py
@@ -0,0 +1,11 @@
+class DummyGurobipyClass:
+ """A dummy class that raises an informative error when gurobipy is not installed."""
+ def __getattr__(self, attr):
+ def _f(*args, **kwargs):
+ raise RuntimeError(f"method {attr} is not available because the gurobipy module is not installed.")
+ return _f
+
+try:
+ import gurobipy as grb
+except ModuleNotFoundError:
+ grb = DummyGurobipyClass()
\ No newline at end of file
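
With this fallback, the import of `grb` always succeeds and the failure is deferred to the first use. For example (assuming gurobipy is absent; with it installed, this builds a real model):

    from auto_LiRPA.operators.solver_utils import grb

    try:
        model = grb.Model("test")
    except RuntimeError as e:
        print(e)   # "method Model is not available because ..."
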
diff --git a/auto_LiRPA/optimized_bounds.py b/auto_LiRPA/optimized_bounds.py
new file mode 100644
index 0000000..7ec1ffd
--- /dev/null
+++ b/auto_LiRPA/optimized_bounds.py
@@ -0,0 +1,1046 @@
+import time
+import os
+import warnings
+from collections import OrderedDict
+from contextlib import ExitStack
+
+import torch
+from torch import optim
+from .cuda_utils import double2float
+
+
+def _set_alpha(optimizable_activations, parameters, alphas, lr):
+ """
+ Set best_alphas, alphas and parameters list
+ """
+ for node in optimizable_activations:
+ alphas.extend(list(node.alpha.values()))
+ node.opt_start()
+ # Alpha has shape (2, output_shape, batch_dim, node_shape)
+ parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2})
+ # best_alpha is a dictionary of dictionaries. Each key is the alpha variable
+ # for one relu layer, and each value is a dictionary whose keys are all
+ # relu layers after that layer.
+ best_alphas = OrderedDict()
+ for m in optimizable_activations:
+ best_alphas[m.name] = {}
+ for alpha_m in m.alpha:
+ best_alphas[m.name][alpha_m] = m.alpha[alpha_m].detach().clone()
+ # We will directly replace the dictionary for each relu layer after
+ # optimization, so the saved alpha might not have requires_grad=True.
+ m.alpha[alpha_m].requires_grad_()
+
+ return best_alphas
+
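
For reference, a parameter group shaped like the one appended above can be handed to a standard optimizer; extra keys such as batch_dim are preserved in optimizer.param_groups. A minimal sketch with made-up shapes (not part of the patch):

    import torch
    from torch import optim

    alphas = [torch.rand(2, 1, 4, 8, requires_grad=True)]
    opt = optim.Adam([{'params': alphas, 'lr': 0.1, 'batch_dim': 2}])
    assert opt.param_groups[0]['batch_dim'] == 2
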
+
+def _set_beta(
+ self, relus, optimizable_activations, single_node_split,
+ enable_opt_interm_bounds, betas, opt_coeffs, parameters,
+ lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter, dense_coeffs_mask):
+ """
+ Set betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases
+ and best_biases.
+ """
+ coeffs = best_coeffs = biases = best_biases = None
+ if len(relus) != len(optimizable_activations):
+ warnings.warn(
+ 'Only ReLU splits are supported so far; this model contains other '
+ 'optimizable activations to which splits may not apply.')
+
+ if single_node_split:
+ for node in relus:
+ if enable_opt_interm_bounds and node.sparse_beta is not None:
+ for key in node.sparse_beta.keys():
+ if node.sparse_beta[key] is not None:
+ betas.append(node.sparse_beta[key])
+ else:
+ if node.sparse_beta is not None:
+ betas.append(node.sparse_beta)
+ else:
+ betas = self.beta_params + self.single_beta_params
+ if opt_coeffs:
+ coeffs = [dense_coeffs['dense']
+ for dense_coeffs in self.split_dense_coeffs_params
+ ] + self.coeffs_params
+ dense_coeffs_mask = [dense_coeffs['mask']
+ for dense_coeffs in self.split_dense_coeffs_params]
+ parameters.append({'params': coeffs, 'lr': lr_coeffs})
+ best_coeffs = [coeff.detach().clone() for coeff in coeffs]
+ if opt_bias:
+ biases = self.bias_params
+ parameters.append({'params': biases, 'lr': lr_coeffs})
+ best_biases = [bias.detach().clone() for bias in biases]
+
+ # Beta has shape (batch, max_splits_per_layer)
+ parameters.append({'params': betas, 'lr': lr_beta, 'batch_dim': 0})
+
+ if self.cut_used:
+ # also need to optimize cut betas
+ parameters.append({'params': self.cut_beta_params,
+ 'lr': lr_cut_beta, 'batch_dim': 0})
+ betas = betas + self.cut_beta_params
+
+ if enable_opt_interm_bounds and betas:
+ best_betas = OrderedDict()
+ for m in optimizable_activations:
+ best_betas[m.name] = {}
+ for beta_m, beta in m.sparse_beta.items():
+ best_betas[m.name][beta_m] = beta.detach().clone()
+ if self.cut_used:
+ best_betas['cut'] = []
+ for general_betas in self.cut_beta_params:
+ best_betas['cut'].append(general_betas.detach().clone())
+ else:
+ best_betas = [b.detach().clone() for b in betas]
+
+ if self.cut_used and getattr(cutter, 'opt', False):
+ parameters.append(cutter.get_parameters())
+
+ return (
+ betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases,
+ best_biases)
+
+
+def _save_ret_first_time(bounds, fill_value, x, best_ret):
+ """
+ Save results at the first iteration to best_ret
+ """
+ if bounds is not None:
+ best_bounds = torch.full_like(
+ bounds, fill_value=fill_value, device=x[0].device, dtype=x[0].dtype)
+ else:
+ best_bounds = None
+
+ if bounds is not None:
+ best_ret.append(bounds.detach().clone())
+ else:
+ best_ret.append(None)
+
+ return best_bounds
+
+
+@torch.no_grad()
+def _get_preserve_mask(
+ decision_thresh, ret_l, preserve_mask, multi_spec_keep_func):
+ """
+ Get preserve mask by decision_thresh to filter out the satisfied bounds.
+ """
+ if (isinstance(decision_thresh, torch.Tensor)
+ and decision_thresh.numel() > 1):
+ if decision_thresh.shape[-1] == 1:
+ now_preserve_mask = (
+ ret_l <= decision_thresh[preserve_mask]
+ ).view(-1).nonzero().view(-1)
+ else:
+ now_preserve_mask = multi_spec_keep_func(
+ ret_l <= decision_thresh[preserve_mask]).nonzero().view(-1)
+ else:
+        if ret_l.shape[-1] == 1:
+ now_preserve_mask = (
+ ret_l <= decision_thresh).view(-1).nonzero().view(-1)
+ else:
+ now_preserve_mask = multi_spec_keep_func(
+ ret_l <= decision_thresh).nonzero().view(-1)
+
+ return now_preserve_mask
+
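+# Example (a small sketch): with a threshold of 0, domains whose lower bound
+# already exceeds the threshold are verified and dropped from the batch:
+#   ret_l = torch.tensor([[-1.], [2.], [-3.]])
+#   _get_preserve_mask(torch.tensor([0.]), ret_l, None, None)
+#   # -> tensor([0, 2]): only the not-yet-verified rows (ret_l <= 0) remain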
+
+def _recover_bounds_to_full_batch(
+ ret, decision_thresh, epsilon_over_decision_thresh, original_size,
+ preserve_mask, loss_reduction_func):
+ """
+    Recover lower and upper bounds to the full batch size, so that l and u
+    can later be updated directly at the full batch size.
+ """
+ if ret is not None:
+ if (isinstance(decision_thresh, torch.Tensor)
+ and decision_thresh.numel() > 1):
+ full_ret = (decision_thresh.clone().to(ret.device).type(ret.dtype)
+ + epsilon_over_decision_thresh)
+ else:
+ num_decision_thresh = decision_thresh
+ if isinstance(num_decision_thresh, torch.Tensor):
+ num_decision_thresh = num_decision_thresh.item()
+ full_ret = torch.full(
+ (original_size,) + tuple(ret.shape[1:]),
+ fill_value=num_decision_thresh + epsilon_over_decision_thresh,
+ device=ret.device, dtype=ret.dtype)
+ full_ret[preserve_mask] = ret
+ if full_ret.shape[1] > 1:
+ full_reduced_ret = loss_reduction_func(full_ret)
+ else:
+ full_reduced_ret = full_ret
+ else:
+ full_ret = full_reduced_ret = None
+
+ return full_ret, full_reduced_ret
+
+
+@torch.no_grad()
+def _prune_bounds_by_mask(
+ now_preserve_mask, decision_thresh, ret_l, ret_u, ret, preserve_mask,
+ epsilon_over_decision_thresh, original_size, loss_reduction_func,
+ beta, intermediate_beta_enabled,
+ fix_intermediate_layer_bounds, intermediate_layer_bounds,
+ aux_reference_bounds, partial_intermediate_layer_bounds,
+ pre_prune_size):
+ """
+ Prune bounds by given now_preserve_mask.
+ """
+ full_ret_l, full_l = _recover_bounds_to_full_batch(
+ ret_l, decision_thresh, epsilon_over_decision_thresh,
+ original_size, preserve_mask, loss_reduction_func)
+
+ full_ret_u, full_u = _recover_bounds_to_full_batch(
+ ret_u, decision_thresh, epsilon_over_decision_thresh,
+ original_size, preserve_mask, loss_reduction_func)
+
+ full_ret = (full_ret_l, full_ret_u) + ret[2:]
+
+ if beta and intermediate_beta_enabled:
+ # prune the partial_intermediate_layer_bounds
+ interval_to_prune = partial_intermediate_layer_bounds
+ elif fix_intermediate_layer_bounds:
+ interval_to_prune = intermediate_layer_bounds
+ else:
+ interval_to_prune = None
+ if interval_to_prune is not None:
+ for k, v in interval_to_prune.items():
+ interm_interval_l, interm_interval_r = v[0], v[1]
+ if interm_interval_l.shape[0] == pre_prune_size:
+ # the first dim is batch size and matches preserve mask
+ interm_interval_l = interm_interval_l[now_preserve_mask]
+ if interm_interval_r.shape[0] == pre_prune_size:
+ # the first dim is batch size and matches preserve mask
+ interm_interval_r = interm_interval_r[now_preserve_mask]
+ interval_to_prune[k] = [interm_interval_l, interm_interval_r]
+
+ if aux_reference_bounds is not None:
+ for k in aux_reference_bounds:
+ aux_ref_l, aux_ref_r = aux_reference_bounds[k]
+ if aux_ref_l.shape[0] == pre_prune_size:
+ # the first dim is batch size and matches the preserve mask
+ aux_ref_l = aux_ref_l[now_preserve_mask]
+ if aux_ref_r.shape[0] == pre_prune_size:
+ # the first dim is batch size and matches the preserve mask
+ aux_ref_r = aux_ref_r[now_preserve_mask]
+ aux_reference_bounds[k] = [aux_ref_l, aux_ref_r]
+
+ # update the global mask here for possible next iteration
+ preserve_mask_next = preserve_mask[now_preserve_mask]
+
+ return full_l, full_ret_l, full_u, full_ret_u, full_ret, preserve_mask_next
+
+
+@torch.no_grad()
+def _prune_x(x, now_preserve_mask):
+ """
+ Prune x by given now_preserve_mask.
+ """
+ x = list(x)
+ pre_prune_size = x[0].shape[0]
+ x[0].data = x[0][now_preserve_mask].data
+ if hasattr(x[0], 'ptb'):
+ if x[0].ptb.x_L is not None:
+ x[0].ptb.x_L = x[0].ptb.x_L[now_preserve_mask]
+ if x[0].ptb.x_U is not None:
+ x[0].ptb.x_U = x[0].ptb.x_U[now_preserve_mask]
+ x = tuple(x)
+
+ return x, pre_prune_size
+
+
+def _to_float64(self, C, x, aux_reference_bounds, intermediate_layer_bounds):
+ """
+ Transfer variables to float64 only in the last iteration to help alleviate
+ floating point error.
+ """
+ self.to(torch.float64)
+ C = C.to(torch.float64)
+ x = self._to(x, torch.float64)
+ # best_intermediate_bounds is linked to aux_reference_bounds!
+    # We only need to call .to() on one of them.
+ self._to(aux_reference_bounds, torch.float64, inplace=True)
+ intermediate_layer_bounds = self._to(
+ intermediate_layer_bounds, torch.float64)
+
+ return C, x, intermediate_layer_bounds
+
+
+def _to_default_dtype(
+ self, x, total_loss, full_ret, ret, best_intermediate_bounds, return_A):
+ """
+    Switch back from float64 to the default precision, typically to
+    accommodate subsequent operations.
+ """
+ total_loss = total_loss.to(torch.get_default_dtype())
+ self.to(torch.get_default_dtype())
+    # Tensor.to() is not in-place; convert x the same way as in _to_float64.
+    x = self._to(x, torch.get_default_dtype())
+ full_ret = list(full_ret)
+ if isinstance(ret[0], torch.Tensor):
+ # round down lower bound
+ full_ret[0] = double2float(full_ret[0], 'down')
+ if isinstance(ret[1], torch.Tensor):
+ # round up upper bound
+ full_ret[1] = double2float(full_ret[1], 'up')
+ for _k, _v in best_intermediate_bounds.items():
+ _v[0] = double2float(_v[0], 'down')
+ _v[1] = double2float(_v[1], 'up')
+ best_intermediate_bounds[_k] = _v
+ if return_A:
+ full_ret[2] = self._to(full_ret[2], torch.get_default_dtype())
+
+ return total_loss, x, full_ret
+
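+# Soundness note (sketch): when casting float64 results back to float32,
+# double2float rounds outward, so the float32 interval still contains the
+# float64 one:
+#   lb32 = double2float(lb64, 'down')  # rounds down: lb32 <= lb64
+#   ub32 = double2float(ub64, 'up')    # rounds up:   ub32 >= ub64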
+
+def _update_best_ret(
+ full_ret_bound, best_ret_bound, full_ret, best_ret, need_update, idx):
+ """Update best_ret_bound and best_ret by comparing with new results."""
+ assert idx in [0, 1], (
+ '0 means updating lower bound, 1 means updating upper bound')
+ if idx == 0:
+ idx_mask = (full_ret_bound > best_ret_bound).any(dim=1).view(-1)
+ else:
+ idx_mask = (full_ret_bound < best_ret_bound).any(dim=1).view(-1)
+
+ improved_idx = None
+ if idx_mask.any():
+ need_update = True
+        # We only pick up the results that improved within the batch.
+ improved_idx = idx_mask.nonzero(as_tuple=True)[0]
+ # total_loss = total_loss.to(best_ret_l)
+ if idx == 0:
+ best_ret_bound[improved_idx] = torch.maximum(
+ full_ret_bound[improved_idx], best_ret_bound[improved_idx])
+ if full_ret[idx] is not None:
+ best_ret[idx][improved_idx] = torch.maximum(
+ full_ret[idx][improved_idx], best_ret[idx][improved_idx])
+ else:
+ best_ret_bound[improved_idx] = torch.minimum(
+ full_ret_bound[improved_idx], best_ret_bound[improved_idx])
+ if full_ret[idx] is not None:
+ best_ret[idx][improved_idx] = torch.minimum(
+ full_ret[idx][improved_idx], best_ret[idx][improved_idx])
+
+ return best_ret_bound, best_ret, idx_mask, improved_idx, need_update
+
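+# Example (a small sketch, idx=0 updates lower bounds): only rows where the
+# new bound improves on the recorded best are overwritten:
+#   best = torch.tensor([[0.], [1.]])
+#   new = torch.tensor([[0.5], [0.5]])
+#   _update_best_ret(new, best, (new,), [best.clone()], False, idx=0)
+#   # best becomes [[0.5], [1.]]: row 0 improved, row 1 is kept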
+
+def _update_optimizable_activations(
+ optimizable_activations, pruning_in_iteration,
+ intermediate_layer_bounds, fix_intermediate_layer_bounds,
+ best_intermediate_bounds, idx, local_idx, alpha, best_alphas):
+ """
+ Update bounds and alpha of optimizable_activations.
+ """
+ for node in optimizable_activations:
+ reference_idx = local_idx if pruning_in_iteration else idx
+ # Update best intermediate layer bounds only when they are optimized.
+ # If they are already fixed in intermediate_layer_bounds, then do
+ # nothing.
+ if (intermediate_layer_bounds is None
+ or node.inputs[0].name not in intermediate_layer_bounds
+ or not fix_intermediate_layer_bounds):
+ best_intermediate_bounds[node.name][0][idx] = torch.max(
+ best_intermediate_bounds[node.name][0][idx],
+ node.inputs[0].lower[reference_idx])
+ best_intermediate_bounds[node.name][1][idx] = torch.min(
+ best_intermediate_bounds[node.name][1][idx],
+ node.inputs[0].upper[reference_idx])
+
+ if alpha:
+ # Each alpha has shape (2, output_shape, batch, *shape) for ReLU.
+            # For other activation functions this can be different.
+ for alpha_m in node.alpha:
+ if node.alpha_batch_dim == 2:
+ best_alphas[node.name][alpha_m][:, :,
+ idx] = node.alpha[alpha_m][:, :, idx]
+ elif node.alpha_batch_dim == 3:
+ best_alphas[node.name][alpha_m][:, :, :,
+ idx] = node.alpha[alpha_m][:, :, :, idx]
+ else:
+ raise ValueError(
+ f'alpha_batch_dim={node.alpha_batch_dim} must be set '
+ 'to 2 or 3 in BoundOptimizableActivation')
+
+
+def _update_best_beta(
+ self, enable_opt_interm_bounds, betas, optimizable_activations,
+ best_betas, idx):
+ """
+    Update the best betas at the given batch indices.
+ """
+ if enable_opt_interm_bounds and betas:
+ for node in optimizable_activations:
+ for key in node.sparse_beta.keys():
+ best_betas[node.name][key] = (
+ node.sparse_beta[key].detach().clone())
+ if self.cut_used:
+ for gbidx, general_betas in enumerate(self.cut_beta_params):
+ best_betas['cut'][gbidx] = general_betas.detach().clone()
+ else:
+ if self.cut_used:
+ regular_beta_length = len(betas) - len(self.cut_beta_params)
+ for beta_idx in range(regular_beta_length):
+ # regular beta crown betas
+ best_betas[beta_idx][idx] = betas[beta_idx][idx]
+ for cut_beta_idx in range(len(self.cut_beta_params)):
+ # general cut beta crown general_betas
+ best_betas[regular_beta_length + cut_beta_idx][:, :, idx,
+ :] = betas[regular_beta_length + cut_beta_idx][:, :, idx, :]
+ else:
+ for beta_idx in range(len(betas)):
+ # regular beta crown betas
+ best_betas[beta_idx][idx] = betas[beta_idx][idx]
+
+
+def get_optimized_bounds(
+ self, x=None, aux=None, C=None, IBP=False, forward=False,
+ method='backward', bound_lower=True, bound_upper=False,
+ reuse_ibp=False, return_A=False, average_A=False, final_node_name=None,
+ intermediate_layer_bounds=None, reference_bounds=None,
+ aux_reference_bounds=None, needed_A_dict=None, cutter=None,
+ decision_thresh=None, epsilon_over_decision_thresh=1e-4):
+ """
+ Optimize CROWN lower/upper bounds by alpha and/or beta.
+ """
+
+ opts = self.bound_opts['optimize_bound_args']
+ iteration = opts['iteration']
+ beta = opts['enable_beta_crown']
+ alpha = opts['enable_alpha_crown']
+ opt_coeffs = opts['opt_coeffs']
+ opt_bias = opts['opt_bias']
+ opt_choice = opts['optimizer']
+ single_node_split = opts['single_node_split']
+ assert single_node_split is True
+ keep_best = opts['keep_best']
+ fix_intermediate_layer_bounds = opts['fix_intermediate_layer_bounds']
+ init_alpha = opts['init_alpha']
+ lr_alpha = opts['lr_alpha']
+ lr_beta = opts['lr_beta']
+ lr_cut_beta = opts['lr_cut_beta']
+ lr_intermediate_beta = opts['lr_intermediate_beta']
+ lr_decay = opts['lr_decay']
+ lr_coeffs = opts['lr_coeffs']
+ loss_reduction_func = opts['loss_reduction_func']
+ stop_criterion_func = opts['stop_criterion_func']
+ use_float64_in_last_iteration = opts['use_float64_in_last_iteration']
+ early_stop_patience = opts['early_stop_patience']
+ intermediate_beta_enabled = opts['intermediate_beta']
+ start_save_best = opts['start_save_best']
+ multi_spec_keep_func = opts['multi_spec_keep_func']
+ enable_opt_interm_bounds = self.bound_opts.get(
+ 'enable_opt_interm_bounds', False)
+ sparse_intermediate_bounds = self.bound_opts.get(
+ 'sparse_intermediate_bounds', False)
+ verbosity = self.bound_opts['verbosity']
+
+ assert bound_lower != bound_upper, (
+ 'we can only optimize lower OR upper bound at one time')
+ assert alpha or beta, (
+        'nothing to optimize, use compute_bounds instead!')
+
+ if C is not None:
+ self.final_shape = C.size()[:2]
+ self.bound_opts.update({'final_shape': self.final_shape})
+ if init_alpha:
+ # TODO: this should set up aux_reference_bounds.
+ self.init_slope(x, share_slopes=opts['use_shared_alpha'],
+ method=method, c=C, final_node_name=final_node_name)
+
+ # Optimizable activations that are actually used and perturbed
+ optimizable_activations = [
+ n for n in self.optimizable_activations if n.used and n.perturbed]
+    # ReLU nodes that are actually used and perturbed
+ relus = [n for n in self.relus if n.used and n.perturbed]
+
+ alphas, betas, parameters = [], [], []
+ dense_coeffs_mask = []
+ partial_intermediate_layer_bounds = None
+
+ if alpha:
+ best_alphas = _set_alpha(
+ optimizable_activations, parameters, alphas, lr_alpha)
+
+ if beta:
+ ret_set_beta = _set_beta(
+ self, relus, optimizable_activations, single_node_split,
+ enable_opt_interm_bounds, betas, opt_coeffs, parameters,
+ lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter,
+ dense_coeffs_mask)
+ betas, best_betas, coeffs = ret_set_beta[:3]
+ dense_coeffs_mask, best_coeffs, biases, best_biases = ret_set_beta[3:]
+
+ start = time.time()
+
+ if (decision_thresh is not None
+ and isinstance(decision_thresh, torch.Tensor)):
+ if decision_thresh.dim() == 1:
+ # add the spec dim to be aligned with compute_bounds return
+ decision_thresh = decision_thresh.unsqueeze(-1)
+
+
+ if opt_choice == 'adam-autolr':
+ opt = AdamElementLR(parameters)
+ elif opt_choice == 'adam':
+ opt = optim.Adam(parameters)
+ elif opt_choice == 'sgd':
+ opt = optim.SGD(parameters, momentum=0.9)
+ else:
+ raise NotImplementedError(opt_choice)
+
+ # Create a weight vector to scale learning rate.
+ loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device)
+ scheduler = optim.lr_scheduler.ExponentialLR(opt, lr_decay)
+
+ if verbosity > 0 and intermediate_beta_enabled:
+ self.print_optimized_beta(relus, intermediate_beta_enabled=True)
+
+ # best_intermediate_bounds is linked to aux_reference_bounds!
+ best_intermediate_bounds = {}
+ if (sparse_intermediate_bounds and aux_reference_bounds is None
+ and reference_bounds is not None):
+ aux_reference_bounds = {}
+ for name, (lb, ub) in reference_bounds.items():
+ aux_reference_bounds[name] = [
+ lb.detach().clone(), ub.detach().clone()]
+ if aux_reference_bounds is None:
+ aux_reference_bounds = {}
+
+ with torch.no_grad():
+ pruning_in_iteration = False
+ # for computing the positive domain ratio
+ original_size = x[0].shape[0]
+ preserve_mask = None
+
+ # record the overhead due to extra operations from pruning-in-iteration
+ pruning_time = 0.
+
+ need_grad = True
+ patience = 0
+ for i in range(iteration):
+ if cutter:
+ # cuts may be optimized by cutter
+ self.cut_module = cutter.cut_module
+
+ intermediate_constr = None
+
+ if not fix_intermediate_layer_bounds:
+ # If we still optimize all intermediate neurons, we can use
+ # intermediate_layer_bounds as reference bounds.
+ reference_bounds = intermediate_layer_bounds
+
+ if i == iteration - 1:
+ # No grad update needed for the last iteration
+ need_grad = False
+
+ if (self.device == 'cuda'
+ and torch.get_default_dtype() == torch.float32
+ and use_float64_in_last_iteration):
+ C, x, intermediate_layer_bounds = _to_float64(
+ self, C, x, aux_reference_bounds, intermediate_layer_bounds)
+
+        # We will use the last update's preserve mask in caller functions to
+        # recover lA, l, u, etc. to the full batch size.
+ self.last_update_preserve_mask = preserve_mask
+ with torch.no_grad() if not need_grad else ExitStack():
+ # ret is lb, ub or lb, ub, A_dict (if return_A is set to true)
+
+ # argument for intermediate_layer_bounds
+ # If we set neuron bounds individually, or if we are optimizing
+ # intermediate layer bounds using beta, we do not set
+ # intermediate_layer_bounds.
+ # When intermediate betas are used, we must set
+ # intermediate_layer_bounds to None because we want to recompute
+ # all intermediate layer bounds.
+ if beta and intermediate_beta_enabled:
+ arg_ilb = partial_intermediate_layer_bounds
+ elif fix_intermediate_layer_bounds:
+ arg_ilb = intermediate_layer_bounds
+ else:
+ arg_ilb = None
+
+ # argument for aux_reference_bounds
+ if sparse_intermediate_bounds:
+ arg_arb = aux_reference_bounds
+ else:
+ arg_arb = None
+
+ ret = self.compute_bounds(
+ x, aux, C, method=method, IBP=IBP, forward=forward,
+ bound_lower=bound_lower, bound_upper=bound_upper,
+ reuse_ibp=reuse_ibp, return_A=return_A,
+ final_node_name=final_node_name, average_A=average_A,
+ intermediate_layer_bounds=arg_ilb,
+ # This is the currently tightest interval, which will be used to
+ # pass split constraints when intermediate betas are used.
+ reference_bounds=reference_bounds,
+ # This is the interval used for checking for unstable neurons.
+ aux_reference_bounds=arg_arb,
+ # These are intermediate layer beta variables and their
+ # corresponding A matrices and biases.
+ intermediate_constr=intermediate_constr,
+ needed_A_dict=needed_A_dict,
+ update_mask=preserve_mask)
+
+ ret_l, ret_u = ret[0], ret[1]
+
+ if (self.cut_used and i % cutter.log_interval == 0
+ and len(self.cut_beta_params) > 0):
+ # betas[-1]: (2(0 lower, 1 upper), spec, batch, num_constrs)
+ if ret_l is not None:
+ print(
+ i, 'lb beta sum:',
+ f'{self.cut_beta_params[-1][0].sum() / ret_l.size(0)},',
+ f'worst {ret_l.min()}')
+            if ret_u is not None:
+                print(
+                    i, 'ub beta sum:',
+                    f'{self.cut_beta_params[-1][1].sum() / ret_u.size(0)},',
+                    f'worst {ret_u.max()}')
+
+ if i == 0:
+ # save results at the first iteration
+ best_ret = []
+ best_ret_l = _save_ret_first_time(
+ ret[0], float('-inf'), x, best_ret)
+ best_ret_u = _save_ret_first_time(
+ ret[1], float('inf'), x, best_ret)
+
+ for node in optimizable_activations:
+ new_intermediate = [
+ node.inputs[0].lower.detach().clone(),
+ node.inputs[0].upper.detach().clone()]
+ best_intermediate_bounds[node.name] = new_intermediate
+ if sparse_intermediate_bounds:
+ # Always using the best bounds so far as the reference
+ # bounds.
+ aux_reference_bounds[node.inputs[0].name] = new_intermediate
+
+ l = ret_l
+ # Reduction over the spec dimension.
+ if ret_l is not None and ret_l.shape[1] != 1:
+ l = loss_reduction_func(ret_l)
+ u = ret_u
+ if ret_u is not None and ret_u.shape[1] != 1:
+ u = loss_reduction_func(ret_u)
+
+        # full_l, full_ret_l and full_u, full_ret_u are used to update the best
+ full_ret_l, full_ret_u = ret_l, ret_u
+ full_l = l
+ full_ret = ret
+
+        # Positive domains may already be filtered out, so we compute the
+        # number of positive domains as (all domains - negative domains).
+ if decision_thresh is not None:
+ if (isinstance(decision_thresh, torch.Tensor)
+ and decision_thresh.numel() > 1
+ and preserve_mask is not None):
+ if decision_thresh.shape[-1] == 1:
+ # single spec with pruned domains
+ negative_domain = (
+ ret_l.view(-1)
+ <= decision_thresh[preserve_mask].view(-1)).sum()
+ else:
+ # multiple spec with pruned domains
+ negative_domain = multi_spec_keep_func(
+ ret_l <= decision_thresh[preserve_mask]).sum()
+ else:
+ if ret_l.shape[-1] == 1:
+ # single spec
+ negative_domain = (
+ ret_l.view(-1) <= decision_thresh.view(-1)).sum()
+ else:
+ # multiple spec
+ negative_domain = multi_spec_keep_func(
+ ret_l <= decision_thresh).sum()
+ positive_domain_num = original_size - negative_domain
+ else:
+ positive_domain_num = -1
+ positive_domain_ratio = float(
+ positive_domain_num) / float(original_size)
+ # threshold is 10% by default
+ next_iter_pruning_in_iteration = (
+ opts['pruning_in_iteration'] and decision_thresh is not None
+ and positive_domain_ratio > opts['pruning_in_iteration_threshold'])
+
+ if pruning_in_iteration:
+ stime = time.time()
+ if return_A:
+ raise Exception(
+ 'Pruning in iteration optimization does not support '
+ 'return A yet. '
+ 'Please fix or discard this optimization by setting '
+ '--disable_pruning_in_iteration '
+ 'or bab: pruning_in_iteration: false')
+ now_preserve_mask = _get_preserve_mask(
+ decision_thresh, ret_l, preserve_mask, multi_spec_keep_func)
+ # prune C
+ if C is not None and C.shape[0] == x[0].shape[0]:
+ C = C[now_preserve_mask] # means C is also batch specific
+ # prune x
+ x, pre_prune_size = _prune_x(x, now_preserve_mask)
+ # prune bounds
+ ret_prune = _prune_bounds_by_mask(
+ now_preserve_mask, decision_thresh, ret_l, ret_u, ret,
+ preserve_mask, epsilon_over_decision_thresh, original_size,
+ loss_reduction_func, beta, intermediate_beta_enabled,
+ fix_intermediate_layer_bounds,
+ intermediate_layer_bounds, aux_reference_bounds,
+ partial_intermediate_layer_bounds, pre_prune_size)
+ full_l, full_ret_l = ret_prune[:2]
+ # ret_prune[2] is full_u which is unused
+ full_ret_u, full_ret, preserve_mask_next = ret_prune[3:]
+ pruning_time += time.time() - stime
+
+ loss_ = l if bound_lower else -u
+ stop_criterion = stop_criterion_func(
+ full_ret_l) if bound_lower else stop_criterion_func(-full_ret_u)
+ if (type(stop_criterion) != bool
+ and stop_criterion.numel() > 1 and pruning_in_iteration):
+ stop_criterion = stop_criterion[preserve_mask]
+ total_loss = -1 * loss_
+ if type(stop_criterion) == bool:
+ loss = total_loss.sum() * (not stop_criterion)
+ else:
+ loss = (total_loss * stop_criterion.logical_not()).sum()
+
+
+ stop_criterion_final = isinstance(
+ stop_criterion, torch.Tensor) and stop_criterion.all()
+
+ if i == iteration - 1:
+ best_ret = list(best_ret)
+ if best_ret[0] is not None:
+ best_ret[0] = best_ret[0].to(torch.get_default_dtype())
+ if best_ret[1] is not None:
+ best_ret[1] = best_ret[1].to(torch.get_default_dtype())
+
+ if (i == iteration - 1 and self.device == 'cuda'
+ and torch.get_default_dtype() == torch.float32
+ and use_float64_in_last_iteration):
+ total_loss, x, full_ret = _to_default_dtype(
+ self, x, total_loss, full_ret, ret, best_intermediate_bounds,
+ return_A)
+
+ with torch.no_grad():
+ # for lb and ub, we update them in every iteration since updating
+ # them is cheap
+ need_update = False
+ if keep_best:
+ if best_ret_u is not None:
+ ret_upd = _update_best_ret(
+ full_ret_u, best_ret_u, full_ret, best_ret, need_update,
+ idx=1)
+ best_ret_u, best_ret, idx_mask, idx, need_update = ret_upd
+ if best_ret_l is not None:
+ ret_upd = _update_best_ret(
+ full_ret_l, best_ret_l, full_ret, best_ret, need_update,
+ idx=0)
+ best_ret_l, best_ret, idx_mask, idx, need_update = ret_upd
+ else:
+ # Not saving the best, just keep the last iteration.
+ if full_ret[0] is not None:
+ best_ret[0] = full_ret[0]
+ if full_ret[1] is not None:
+ best_ret[1] = full_ret[1]
+ if return_A:
+ # FIXME: A should also be updated by idx.
+ best_ret = [best_ret[0], best_ret[1], full_ret[2]]
+
+ # Save variables if this is the best iteration.
+            # To save computational cost, we only check keep_best in the first
+            # iteration (in case of divergence), after the start_save_best
+            # fraction of iterations, or right before early stop (either
+            # stop_criterion holds or early_stop_patience is reached).
+ # if i < 1 or i > iteration / 2 or stop_criterion_final or
+ # patience == early_stop_patience:
+ if (i < 1 or i > int(iteration * start_save_best)
+ or stop_criterion_final or patience == early_stop_patience):
+ if need_update:
+ patience = 0 # bounds improved, reset patience
+ local_idx = None
+                    # For the update, we restrict idx to the preserved
+                    # domains only.
+ if pruning_in_iteration:
+ # local sparse index of preserved samples where
+ # idx = true
+ local_idx = idx_mask[preserve_mask].nonzero().view(-1)
+ # idx is global sparse index of preserved samples where
+ # idx = true
+ new_idx = torch.zeros_like(
+ idx_mask, dtype=torch.bool, device=idx.device)
+ new_idx[preserve_mask] = idx_mask[preserve_mask]
+ idx = new_idx.nonzero().view(-1)
+
+ _update_optimizable_activations(
+ optimizable_activations, pruning_in_iteration,
+ intermediate_layer_bounds,
+ fix_intermediate_layer_bounds,
+ best_intermediate_bounds, idx, local_idx, alpha,
+ best_alphas)
+
+ if beta and single_node_split:
+ _update_best_beta(
+ self, enable_opt_interm_bounds, betas,
+ optimizable_activations, best_betas, idx)
+
+ else:
+ patience += 1
+
+ if os.environ.get('AUTOLIRPA_DEBUG_OPT', False):
+ print(f'****** iter [{i}]',
+ f'loss: {loss.item()}, lr: {opt.param_groups[0]["lr"]}')
+
+ if stop_criterion_final:
+ print(f'\nall verified at {i}th iter')
+ break
+
+ if patience > early_stop_patience:
+ print(
+ f'Early stop at {i}th iter due to {early_stop_patience}'
+ ' iterations no improvement!')
+ break
+
+ current_lr = [param_group['lr'] for param_group in opt.param_groups]
+
+ opt.zero_grad(set_to_none=True)
+
+ if verbosity > 2:
+ print(
+ f'*** iter [{i}]\n', f'loss: {loss.item()}',
+ total_loss.squeeze().detach().cpu().numpy(), 'lr: ',
+ current_lr)
+ if beta:
+ self.print_optimized_beta(relus, intermediate_beta_enabled)
+ if opt_coeffs:
+ for co in coeffs:
+ print(f'coeff sum: {co.abs().sum():.5g}')
+ if beta and i == 0 and verbosity > 2:
+ breakpoint()
+
+ if i != iteration - 1:
+            # We do not need to update parameters in the last step since the
+            # best result has already been obtained.
+ loss.backward()
+ # All intermediate variables are not needed at this point.
+ self._clear_and_set_new(None)
+ if opt_choice == 'adam-autolr':
+ opt.step(lr_scale=[loss_weight, loss_weight])
+ else:
+ opt.step()
+
+ if beta:
+ # Clipping to >=0.
+ for b in betas:
+ b.data = (b >= 0) * b.data
+ for dmi in range(len(dense_coeffs_mask)):
+ # apply dense mask to the dense split coeffs matrix
+ coeffs[dmi].data = (
+ dense_coeffs_mask[dmi].float() * coeffs[dmi].data)
+
+
+ if alpha:
+ for m in optimizable_activations:
+ m.clip_alpha_()
+
+ scheduler.step()
+
+ if pruning_in_iteration:
+ preserve_mask = preserve_mask_next
+ if not pruning_in_iteration and next_iter_pruning_in_iteration:
+ # init preserve_mask etc
+ preserve_mask = torch.arange(
+ 0, x[0].shape[0], device=x[0].device, dtype=torch.long)
+ pruning_in_iteration = True
+
+ if pruning_in_iteration:
+ # overwrite pruned cells in best_ret by threshold + eps
+ if return_A:
+ fin_l, fin_u, fin_A = best_ret
+ else:
+ fin_l, fin_u = best_ret
+ fin_A = None
+ if fin_l is not None:
+ new_fin_l = full_ret_l
+ new_fin_l[preserve_mask] = fin_l[preserve_mask]
+ fin_l = new_fin_l
+ if fin_u is not None:
+ new_fin_u = full_ret_u
+ new_fin_u[preserve_mask] = fin_u[preserve_mask]
+ fin_u = new_fin_u
+ if return_A:
+ best_ret = (fin_l, fin_u, fin_A)
+ else:
+ best_ret = (fin_l, fin_u)
+
+ if verbosity > 3:
+ breakpoint()
+
+ if keep_best:
+ def update_best(dest, src):
+ for item_dest, item_src in zip(dest, src):
+ if enable_opt_interm_bounds:
+ for key in item_dest.keys():
+ item_dest[key].data = item_src[key].data
+ else:
+ item_dest.data = item_src.data
+
+ # Set all variables to their saved best values.
+ with torch.no_grad():
+ for idx, node in enumerate(optimizable_activations):
+ if alpha:
+ # Assigns a new dictionary.
+ node.alpha = best_alphas[node.name]
+ # Update best intermediate layer bounds only when they are
+ # optimized. If they are already fixed in
+ # intermediate_layer_bounds, then do nothing.
+ best_intermediate = best_intermediate_bounds[node.name]
+ node.inputs[0].lower.data = best_intermediate[0].data
+ node.inputs[0].upper.data = best_intermediate[1].data
+ if beta:
+ if (single_node_split and hasattr(node, 'sparse_beta')
+ and node.sparse_beta is not None):
+ if enable_opt_interm_bounds:
+ for key in node.sparse_beta.keys():
+ node.sparse_beta[key].copy_(
+ best_betas[node.name][key])
+ else:
+ node.sparse_beta.copy_(best_betas[idx])
+ else:
+ update_best(betas, best_betas)
+ if opt_coeffs:
+ update_best(coeffs, best_coeffs)
+ if opt_bias:
+ update_best(biases, best_biases)
+ if self.cut_used:
+ regular_beta_length = len(betas) - len(self.cut_beta_params)
+ for ii in range(len(self.cut_beta_params)):
+ self.cut_beta_params[ii].data = best_betas[
+ regular_beta_length + ii].data
+
+ if (intermediate_layer_bounds is not None
+ and not fix_intermediate_layer_bounds):
+ for l in self._modules.values():
+ if (l.name in intermediate_layer_bounds.keys()
+ and hasattr(l, 'lower')):
+ l.lower = torch.max(
+ l.lower, intermediate_layer_bounds[l.name][0])
+ l.upper = torch.min(
+ l.upper, intermediate_layer_bounds[l.name][1])
+ infeasible_neurons = l.lower > l.upper
+ if infeasible_neurons.any():
+ print(
+ f'Infeasibility detected in layer {l.name}.',
+ infeasible_neurons.sum().item(),
+ infeasible_neurons.nonzero()[:, 0])
+
+ if verbosity > 0:
+ if self.cut_used and beta:
+ print(
+ 'first 10 best general betas:',
+ best_betas[-1].view(2, -1)[0][:10], 'sum:',
+ best_betas[-1][0].sum().item())
+ if best_ret_l is not None:
+ # FIXME: unify the handling of l and u.
+ print(
+ 'best_l after optimization:',
+ best_ret_l.sum().item(), 'with beta sum per layer:',
+ [p.sum().item() for p in betas])
+ print('alpha/beta optimization time:', time.time() - start)
+
+ for node in optimizable_activations:
+ node.opt_end()
+
+ # update pruning ratio
+ if (opts['pruning_in_iteration'] and decision_thresh is not None
+ and full_l.numel() > 0):
+ stime = time.time()
+ with torch.no_grad():
+ if isinstance(decision_thresh, torch.Tensor):
+ if decision_thresh.shape[-1] == 1:
+ neg_domain_num = torch.sum(
+ full_ret_l.view(-1) <= decision_thresh.view(-1)).item()
+ else:
+ neg_domain_num = torch.sum(multi_spec_keep_func(
+ full_ret_l <= decision_thresh)).item()
+ else:
+ if full_l.shape[-1] == 1:
+ neg_domain_num = torch.sum(
+ full_ret_l.view(-1) <= decision_thresh).item()
+ else:
+ neg_domain_num = torch.sum(multi_spec_keep_func(
+ full_ret_l <= decision_thresh)).item()
+ now_pruning_ratio = (1.0 -
+ float(neg_domain_num) / float(full_l.shape[0]))
+ print('pruning_in_iteration open status:', pruning_in_iteration)
+ print(
+ 'ratio of positive domain =', full_l.shape[0] - neg_domain_num,
+ '/', full_l.numel(), '=', now_pruning_ratio)
+ pruning_time += time.time() - stime
+ print('pruning-in-iteration extra time:', pruning_time)
+
+ return best_ret
+
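+# Typical entry point (a sketch following the public auto_LiRPA API):
+# get_optimized_bounds is usually not called directly; it is reached through
+# BoundedModule.compute_bounds with an optimized method, e.g.
+#   lb, ub = model.compute_bounds(x=(x,), method='CROWN-Optimized')
+# with hyperparameters such as 'iteration', 'lr_alpha' and 'lr_beta' supplied
+# via bound_opts={'optimize_bound_args': {...}} when constructing the model.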
+
+def init_slope(
+ self, x, share_slopes=False, method='backward',
+ c=None, bound_lower=True, bound_upper=True, final_node_name=None,
+ intermediate_layer_bounds=None, activation_opt_params=None,
+ skip_bound_compute=False):
+ for node in self.optimizable_activations:
+ # initialize the parameters
+ node.opt_init()
+
+ if (not skip_bound_compute or intermediate_layer_bounds is None or
+ activation_opt_params is None or not all(
+ [relu.name in activation_opt_params for relu in self.relus])):
+ skipped = False
+        # If the new interval is None, the CROWN interval is not present.
+        # In this case, we still need to redo a CROWN pass to initialize
+        # the lower/upper bounds.
+ with torch.no_grad():
+ l, u = self.compute_bounds(
+ x=x, C=c, method=method, bound_lower=bound_lower,
+ bound_upper=bound_upper, final_node_name=final_node_name,
+ intermediate_layer_bounds=intermediate_layer_bounds)
+ else:
+        # We skip the bound computation, but we still need to figure out the
+        # "used", "perturbed" and "backward_from" properties of each node in
+        # the graph.
+        skipped = True
+        # This sets the "perturbed" property.
+ self._set_input(
+ *x, intermediate_layer_bounds=intermediate_layer_bounds)
+
+ final = self.final_node(
+ ) if final_node_name is None else self[final_node_name]
+ self._set_used_nodes(final)
+
+ self.backward_from = {node: [final] for node in self._modules}
+
+ final_node_name = final_node_name or self.final_name
+
+ init_intermediate_bounds = {}
+ for node in self.optimizable_activations:
+ if not node.used or not node.perturbed:
+ continue
+ start_nodes = []
+ if method in ['forward', 'forward+backward']:
+ start_nodes.append(('_forward', 1, None))
+ if method in ['backward', 'forward+backward']:
+ start_nodes += self.get_alpha_crown_start_nodes(
+ node, c=c, share_slopes=share_slopes,
+ final_node_name=final_node_name)
+ if skipped:
+ node.restore_optimized_params(activation_opt_params[node.name])
+ else:
+ node.init_opt_parameters(start_nodes)
+ init_intermediate_bounds[node.inputs[0].name] = (
+ [node.inputs[0].lower.detach(), node.inputs[0].upper.detach()])
+
+ if self.bound_opts['verbosity'] >= 1:
+ print('Optimizable variables initialized.')
+ if skip_bound_compute:
+ return init_intermediate_bounds
+ else:
+ return l, u, init_intermediate_bounds
diff --git a/auto_LiRPA/parse_graph.py b/auto_LiRPA/parse_graph.py
index a2d054f..3137b7c 100644
--- a/auto_LiRPA/parse_graph.py
+++ b/auto_LiRPA/parse_graph.py
@@ -1,15 +1,15 @@
import os
import torch
+from torch.onnx.utils import _optimize_graph
+from torch.onnx.symbolic_helper import _set_opset_version
from collections import OrderedDict
-import re
from collections import namedtuple
-from torch.onnx import OperatorExportTypes
-from packaging import version
+import re
from .bounded_tensor import BoundedTensor, BoundedParameter
from .utils import logger, unpack_inputs
Node = namedtuple('Node', (
- 'name', 'ori_name', 'inputs', 'attr', 'op', 'param', 'input_index',
+ 'name', 'ori_name', 'inputs', 'attr', 'op', 'param', 'input_index',
'bound_node', 'output_index', 'perturbation'), defaults=(None,) * 10)
def get_node_name(node):
@@ -17,14 +17,14 @@ def get_node_name(node):
def parse_graph(graph, inputs, params):
input_all = []
- input_used = []
+ input_used = []
scope = {}
for n in graph.inputs():
input_all.append(n.debugName())
for n in graph.nodes():
n_inputs = [get_node_name(i) for i in n.inputs()]
for inp in n.inputs():
- input_used.append(inp.debugName())
+ input_used.append(inp.debugName())
for out in n.outputs():
scope[get_node_name(out)] = n.scopeName()
for node in graph.inputs():
@@ -39,17 +39,18 @@ def parse_graph(graph, inputs, params):
def name_with_scope(node):
name = get_node_name(node)
return '/'.join([scope[name], name])
-
+
nodesOP = []
for n in graph.nodes():
attrs = {k: n[k] for k in n.attributeNames()}
n_inputs = [name_with_scope(i) for i in n.inputs()]
for i, out in enumerate(list(n.outputs())):
+
nodesOP.append(Node(**{'name': name_with_scope(out),
'op': n.kind(),
'inputs': n_inputs,
'attr': attrs,
- 'output_index': i,
+ 'output_index': i,
}))
# filter out input nodes in `graph.inputs()` that are actually used
@@ -64,7 +65,7 @@ def name_with_scope(node):
# filter out input nodes in `inputs` that are actually used
inputs_unpacked = unpack_inputs(inputs)
assert len(list(graph.inputs())) == len(inputs_unpacked) + len(params)
- inputs = [inputs_unpacked[i] for i in range(len(inputs_unpacked)) if used_by_index[i]]
+ inputs = [inputs_unpacked[i] for i in range(len(inputs_unpacked)) if used_by_index[i]]
# index of the used inputs among all the inputs
input_index = [i for i in range(len(inputs_unpacked)) if used_by_index[i]]
# Add a name to all inputs
@@ -72,7 +73,7 @@ def name_with_scope(node):
# filter out params that are actually used
params = [params[i] for i in range(len(params)) if used_by_index[i + len(inputs_unpacked)]]
inputs_and_params = inputs + params
- assert len(nodesIn) == len(inputs_and_params)
+ assert len(nodesIn) == len(inputs_and_params)
# output nodes of the module
nodesOut = []
@@ -81,7 +82,7 @@ def name_with_scope(node):
nodesOut.append(name_with_scope(n))
for i, n in enumerate(nodesIn):
- if (isinstance(inputs_and_params[i][1], BoundedTensor) or
+ if (isinstance(inputs_and_params[i][1], BoundedTensor) or
isinstance(inputs_and_params[i][1], BoundedParameter)):
perturbation = inputs_and_params[i][1].ptb
else:
@@ -92,10 +93,10 @@ def name_with_scope(node):
nodesIn[i] = Node(**{'name': name_with_scope(n),
'ori_name': inputs_and_params[i][0],
'op': 'Parameter',
- 'inputs': [],
+ 'inputs': [],
'attr': str(n.type()),
'param': inputs_and_params[i][1] if i >= len(inputs) else None,
- # index among all the inputs including unused ones
+ # index among all the inputs including unused ones
'input_index': input_index[i] if i < len(inputs) else None,
# Input nodes may have perturbation, if they are wrapped in BoundedTensor or BoundedParameters
'perturbation': perturbation, })
@@ -123,7 +124,7 @@ def _get_jit_params(module, param_exclude, param_include):
return params
-"""Construct a template for the module output with `None` representing places
+"""Construct a template for the module output with `None` representing places
to be filled with tensor results"""
def get_output_template(out):
if isinstance(out, torch.Tensor):
@@ -142,12 +143,10 @@ def get_output_template(out):
def parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=None):
params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include)
- # PyTorch>=1.5 is required here
trace, out = torch.jit._get_trace_graph(module, inputs)
- from torch.onnx.symbolic_helper import _set_opset_version
_set_opset_version(12)
- trace_graph = torch.onnx.utils._optimize_graph(trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, params_dict={})
-
+ trace_graph = _optimize_graph(
+ trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, params_dict={})
logger.debug('trace_graph: {}'.format(trace_graph))
if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0:
@@ -158,7 +157,7 @@ def parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=No
if not isinstance(inputs, tuple):
inputs = (inputs, )
-
+
nodesOP, nodesIn, nodesOut = parse_graph(trace_graph, tuple(inputs), tuple(params))
for i in range(len(nodesOP)):
diff --git a/auto_LiRPA/patches.py b/auto_LiRPA/patches.py
new file mode 100644
index 0000000..87a7cbe
--- /dev/null
+++ b/auto_LiRPA/patches.py
@@ -0,0 +1,502 @@
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def insert_zeros(image, s):
+ """
+    Insert s columns and rows of zeros between every pixel of the image. For example:
+ image = [[1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]]
+ s = 2
+ output = [[1, 0, 0, 2, 0, 0, 3],
+ [0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0],
+ [4, 0, 0, 5, 0, 0, 6],
+ [0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0],
+ [7, 0, 0, 8, 0, 0, 9]]
+ """
+ if s <= 0:
+ return image
+ matrix = torch.zeros(size=(image.size(0), image.size(1), image.size(2) * (s+1) - s, image.size(3) * (s+1) - s), dtype=image.dtype, device=image.device)
+ matrix_stride = matrix.stride()
+ selected_matrix = torch.as_strided(matrix, [
+ # Shape of the output matrix.
+ matrix.size(0), # Batch size.
+ matrix.size(1), # Channel.
+ image.size(2), # H (without zeros)
+ image.size(3), # W (without zeros)
+ ], [
+ # Stride of the output matrix.
+ matrix_stride[0], # Batch size dimension, keep using the old stride.
+ matrix_stride[1], # Channel dimension.
+ matrix_stride[2] * (s + 1), # Move s+1 rows.
+        s + 1,  # Move s+1 pixels (on the width direction).
+    ])
+ selected_matrix[:] = image
+ return matrix
+
+
+def remove_zeros(image, s):
+ if s <= 0:
+ return image
+ matrix_stride = image.stride()
+ return torch.as_strided(image, [
+ # Shape of the output matrix.
+ *image.shape[:-2],
+        (image.size(-2) + s) // (s + 1),  # H (without zeros)
+        (image.size(-1) + s) // (s + 1),  # W (without zeros)
+ ], [
+ # Stride of the output matrix.
+ *matrix_stride[:-2],
+ matrix_stride[-2] * (s + 1), # Move s+1 rows.
+ matrix_stride[-1] * (s + 1), # Move s+1 pixels.
+ ])
+
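+# Example (a small sketch): remove_zeros undoes insert_zeros as a strided
+# view, without copying data:
+#   x = torch.arange(4.).view(1, 1, 2, 2)
+#   z = insert_zeros(x, 1)                     # shape (1, 1, 3, 3)
+#   assert torch.equal(remove_zeros(z, 1), x)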
+
+def unify_shape(shape):
+ """Convert shapes to 4-tuple."""
+ if shape is not None:
+ if isinstance(shape, int):
+ shape = (shape, shape, shape, shape)
+ if len(shape) == 2:
+ shape = (shape[1], shape[1], shape[0], shape[0])
+ assert len(shape) == 4
+ return shape
+
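+# Example (sketch): all accepted forms normalize to (left, right, top, bottom):
+#   unify_shape(1)       # -> (1, 1, 1, 1)
+#   unify_shape((2, 3))  # -> (3, 3, 2, 2), i.e. (H, W) -> (W, W, H, H)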
+
+def is_shape_used(shape, expected=0):
+ if isinstance(shape, int):
+ return shape != expected
+ else:
+ return sum(shape) != expected
+
+
+class Patches:
+ """
+    A special class which denotes a convolutional operator as a group of patches.
+    The shape of Patches.patches is [batch_size, num_of_patches, out_channel, in_channel, M, M],
+    where M is the size of a single patch.
+    Assume that we have a conv2d layer with weight (out_channel, in_channel, M, M), stride and padding, applied on an image (N * N); then
+    num_of_patches = ((N + padding * 2 - M) // stride + 1) ** 2.
+    Here we only consider kernels with the same H and W.
+ """
+ def __init__(
+ self, patches=None, stride=1, padding=0, shape=None, identity=0,
+ unstable_idx=None, output_shape=None, inserted_zeros=0, output_padding=0, input_shape=None):
+ self.patches = patches
+ self.stride = stride
+ self.padding = padding
+ self.shape = shape
+ self.identity = identity
+ self.unstable_idx = unstable_idx
+ self.output_shape = output_shape
+ self.input_shape = input_shape
+ self.inserted_zeros = inserted_zeros
+ self.output_padding = output_padding
+ self.simplify()
+
+ def __add__(self, other):
+ if isinstance(other, Patches):
+            # Insert zeros into the images to make the strides match, if necessary.
+ assert self.stride == other.stride
+ if self.unstable_idx is not None or other.unstable_idx is not None:
+ if self.unstable_idx is not other.unstable_idx: # Same tuple object.
+ raise ValueError('Please set bound option "sparse_conv_intermediate_bounds" to False to run this model.')
+ assert self.output_shape == other.output_shape
+ A1 = self.patches
+ A2 = other.patches
+ # change paddings to merge the two patches
+ sp = torch.tensor(unify_shape(self.padding))
+ op = torch.tensor(unify_shape(other.padding))
+ if (sp - op).abs().sum().item() > 0:
+ if (sp - op >= 0).all():
+ A2 = F.pad(A2, (sp - op).tolist())
+ elif (sp - op <= 0).all():
+ A1 = F.pad(A1, (op - sp).tolist())
+ else:
+ raise ValueError("Unsupported padding size")
+ ret = A1 + A2
+ return Patches(ret, other.stride, torch.max(sp, op).tolist(),
+ ret.shape, unstable_idx=self.unstable_idx, output_shape=self.output_shape,
+ inserted_zeros=self.inserted_zeros, output_padding=self.output_padding)
+ else:
+ assert self.inserted_zeros == 0
+ assert not is_shape_used(self.output_padding)
+ # Patches has shape (out_c, batch, out_h, out_w, in_c, h, w).
+ input_shape = other.shape[3:]
+ matrix = other
+ pieces = self.patches
+ if pieces.ndim == 9:
+ pieces = pieces.transpose(0, 1)
+ pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5]*pieces.shape[6], pieces.shape[7], pieces.shape[8]).transpose(0,1)
+ if pieces.ndim == 8:
+ pieces = pieces.transpose(0, 1)
+ pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5], pieces.shape[6], pieces.shape[7]).transpose(0,1)
+ A1_matrix = patches_to_matrix(
+ pieces, input_shape, self.stride, self.padding,
+ output_shape=self.output_shape, unstable_idx=self.unstable_idx)
+ return A1_matrix.transpose(0, 1) + matrix
+
+ def create_similar(self, patches=None, stride=None, padding=None, identity=None,
+ unstable_idx=None, output_shape=None, inserted_zeros=None, output_padding=None,
+ input_shape=None):
+ """
+ Create a new Patches object with new patches weights, and keep other properties the same.
+ """
+ new_patches = self.patches if patches is None else patches
+ return Patches(
+ new_patches,
+ stride=self.stride if stride is None else stride,
+ padding=self.padding if padding is None else padding,
+ shape=new_patches.shape,
+ identity=self.identity if identity is None else identity,
+ unstable_idx=self.unstable_idx if unstable_idx is None else unstable_idx,
+ output_shape=self.output_shape if output_shape is None else output_shape,
+ inserted_zeros=self.inserted_zeros if inserted_zeros is None else inserted_zeros,
+ output_padding=self.output_padding if output_padding is None else output_padding,
+ input_shape=self.input_shape if input_shape is None else input_shape,
+ )
+
+ def to_matrix(self, input_shape):
+ assert self.inserted_zeros == 0
+ assert not is_shape_used(self.output_padding)
+ return patches_to_matrix(self.patches, input_shape, self.stride, self.padding, self.output_shape, self.unstable_idx)
+
+ def simplify(self):
+ """Merge stride and inserted_zeros; if they are the same they can cancel out."""
+ stride = [self.stride, self.stride] if isinstance(self.stride, int) else self.stride
+ if self.inserted_zeros > 0 and self.inserted_zeros + 1 == stride[0] and stride[0] == stride[1]:
+ # print(f'before simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}')
+ full_stride = [stride[1], stride[1], stride[0], stride[0]]
+ # output_padding = tuple(p // s for p, s in zip(output_padding, full_stride))
+ self.padding = tuple(p // s - o for p, s, o in zip(self.padding, full_stride, unify_shape(self.output_padding)))
+ self.patches = remove_zeros(self.patches, self.inserted_zeros)
+ self.stride = 1
+ self.inserted_zeros = 0
+ self.output_padding = 0
+ # print(f'after simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}')
+
+ def matmul(self, input, patch_abs=False, input_shape=None):
+ """
+ Broadcast multiplication for patches and a matrix.
+
+ Input shape: (batch_size, in_c, in_h, in_w).
+        If any of the in_c, in_h, in_w dimensions equals 1, the input will be expanded to the given input_shape to support broadcasting.
+
+ Output shape: [batch_size, unstable_size] when unstable_idx is not None,
+        [batch_size, out_c, out_h, out_w] when unstable_idx is None.
+ """
+
+ patches = self.patches
+ if patch_abs:
+ patches = patches.abs()
+
+ if input_shape is not None:
+ # For cases that input only has fewer dimensions like (1, in_c, 1, 1)
+ input = input.expand(input_shape)
+ # Expand to (batch_size, in_c, in_h, in_w)
+
+ # unfold the input as [batch_size, out_h, out_w, in_c, H, W]
+ unfold_input = inplace_unfold(
+ input, kernel_size=patches.shape[-2:],
+ padding=self.padding, stride=self.stride,
+ inserted_zeros=self.inserted_zeros, output_padding=self.output_padding)
+ if self.unstable_idx is not None:
+ # We need to add a out_c dimension and select from it.
+ unfold_input = unfold_input.unsqueeze(0).expand(self.output_shape[1], -1, -1, -1, -1, -1, -1)
+ # Shape: [unstable_size, batch_size, in_c, H, W].
+ # Here unfold_input will match this shape.
+ unfold_input = unfold_input[self.unstable_idx[0], :, self.unstable_idx[1], self.unstable_idx[2]]
+ # shape: [batch_size, unstable_size].
+ return torch.einsum('sbchw,sbchw->bs', unfold_input, patches)
+ else:
+ # shape: [batch_size, out_c, out_h, out_w].
+ return torch.einsum('bijchw,sbijchw->bsij', unfold_input, patches)
+
+
+def compute_patches_stride_padding(input_shape, patches_padding, patches_stride, op_padding, op_stride, inserted_zeros=0, output_padding=0, simplify=True):
+ """
+ Compute stride and padding after a conv layer with patches mode.
+ """
+ for p in (patches_padding, patches_stride, op_padding, op_stride):
+ assert isinstance(p, int) or (isinstance(p, (list, tuple)) and (len(p) == 2 or len(p) == 4))
+ # If p is int, then same padding on all 4 sides.
+ # If p is 2-tuple, then it is padding p[0] on both sides of H, p[1] on both sides of W
+ # If p is 4-tuple, then it is padding p[2], p[3] on top and bottom sides of H, p[0] and p[1] on left and right sides of W
+
+ # If any of the inputs are not tuple/list, we convert them to tuple.
+ full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [
+ (p, p) if isinstance(p, int) else p for p in [patches_padding, op_padding, patches_stride, op_stride]]
+ full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [
+ (p[1], p[1], p[0], p[0]) if len(p) == 2 else p for p in [full_patch_padding, full_op_padding, full_patch_stride, full_op_stride]]
+ # Compute the new padding and stride after this layer.
+ new_padding = tuple(pp * os + op * (inserted_zeros + 1) for pp, op, os in zip(full_patch_padding, full_op_padding, full_op_stride))
+ new_stride = tuple(ps * os for ps, os in zip(full_patch_stride, full_op_stride))
+
+ output_padding = unify_shape(output_padding)
+ new_output_padding = (output_padding[0], # Left
+ output_padding[1] + inserted_zeros * input_shape[3] % full_op_stride[2], # Right
+ output_padding[2], # Top
+ output_padding[3] + inserted_zeros * input_shape[2] % full_op_stride[0]) # Bottom
+
+ # Merge into a single number if all numbers are identical.
+ if simplify:
+ if new_padding.count(new_padding[0]) == len(new_padding):
+ new_padding = new_padding[0]
+ if new_stride.count(new_stride[0]) == len(new_stride):
+ new_stride = new_stride[0]
+
+ return new_padding, new_stride, new_output_padding
+
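+# Example (a small sketch): patches with padding 1 / stride 1 propagated
+# through a conv layer with padding 1 / stride 2 and no inserted zeros:
+#   pad, stride, out_pad = compute_patches_stride_padding(
+#       input_shape=(1, 3, 8, 8), patches_padding=1, patches_stride=1,
+#       op_padding=1, op_stride=2)
+#   # pad == 3 (= 1*2 + 1*1), stride == 2 (= 1*2), out_pad == (0, 0, 0, 0)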
+
+def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, unstable_idx=None):
+ """Converting a Patches piece into a full dense matrix."""
+ if type(padding) == int:
+ padding = (padding, padding, padding, padding)
+
+ if pieces.ndim == 9:
+ # Squeeze two additional dimensions for output and input respectively
+ assert pieces.shape[1] == 1 and pieces.shape[5] == 1
+ pieces = pieces.reshape(
+ pieces.shape[0], *pieces.shape[2:5],
+ *pieces.shape[6:]
+ )
+
+ if unstable_idx is None:
+ assert pieces.ndim == 7
+ # Non-sparse pieces, with shape (out_c, batch, out_h, out_w, c, h, w).
+ output_channel, batch_size, output_x, output_y = pieces.shape[:4]
+ else:
+ batch_size = pieces.shape[1]
+ output_channel, output_x, output_y = output_shape[1:]
+ input_channel, kernel_x, kernel_y = pieces.shape[-3:]
+ input_x, input_y = input_shape[-2:]
+
+ if unstable_idx is None:
+ # Fix all patches in a full A matrix.
+ A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype)
+        # Save its original stride.
+ orig_stride = A_matrix.stride()
+ # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.
+ # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
+ matrix_strided = torch.as_strided(A_matrix, [batch_size, output_channel, output_x, output_y, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], orig_stride[2], orig_stride[3], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[4], input_y + padding[0] + padding[1], 1])
+ # Now we need to fill the conv kernel parameters into the last three dimensions of matrix_strided.
+ first_indices = torch.arange(output_x * output_y, device=pieces.device)
+ second_indices = torch.div(first_indices, output_y, rounding_mode="trunc")
+ third_indices = torch.fmod(first_indices, output_y)
+ # pieces have shape (out_c, batch, out_h, out_w, c, h, w).
+ pieces = pieces.transpose(0, 1) # pieces has the out_c dimension at the front, need to move it to the second.
+ matrix_strided[:,:,second_indices,third_indices,second_indices,third_indices,:,:,:] = pieces.reshape(*pieces.shape[:2], -1, *pieces.shape[4:])
+ A_matrix = A_matrix.view(batch_size, output_channel * output_x * output_y, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1])
+ else:
+ # Fill only a selection of patches.
+ # Create only a partial A matrix.
+ unstable_size = unstable_idx[0].numel()
+ A_matrix = torch.zeros(batch_size, unstable_size, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype)
+        # Save its original stride.
+ orig_stride = A_matrix.stride()
+ # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.
+ # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
+ matrix_strided = torch.as_strided(A_matrix, [batch_size, unstable_size, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[2], input_y + padding[0] + padding[1], 1])
+ # pieces have shape (unstable_size, batch, c, h, w).
+ first_indices = torch.arange(unstable_size, device=pieces.device)
+ matrix_strided[:,first_indices,unstable_idx[1],unstable_idx[2],:,:,:] = pieces.transpose(0, 1).to(matrix_strided)
+ A_matrix = A_matrix.view(batch_size, unstable_size, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1])
+
+ A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]]
+
+ return A_matrix
+
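+# Self-check (a sketch; F is torch.nn.functional imported above): for a plain
+# convolution, expanding the kernel into patches and densifying reproduces
+# conv2d without bias:
+#   N, Ci, Co, H, W, k = 2, 3, 4, 6, 6, 3
+#   x = torch.randn(N, Ci, H, W)
+#   w = torch.randn(Co, Ci, k, k)
+#   oh = ow = H - k + 1  # stride 1, no padding
+#   pieces = w.view(Co, 1, 1, 1, Ci, k, k).expand(Co, N, oh, ow, Ci, k, k)
+#   A = patches_to_matrix(pieces, (N, Ci, H, W), stride=1, padding=0)
+#   ref = F.conv2d(x, w).reshape(N, -1)
+#   assert torch.allclose(torch.einsum('bochw,bchw->bo', A, x), ref, atol=1e-5)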
+
+def check_patch_biases(lb, ub, lower_b, upper_b):
+    # When we use patches mode, it's possible that we need to add two biases:
+    # one from the tensor mode and one from the patches mode. We need to
+    # detect this case and reshape the biases accordingly.
+ if lower_b.ndim < lb.ndim:
+ lb = lb.transpose(0,1).reshape(lb.size(1), lb.size(0), -1)
+ lb = lb.expand(lb.size(0), lb.size(1), lower_b.size(0)//lb.size(1))
+ lb = lb.reshape(lb.size(0), -1).t()
+ ub = ub.transpose(0,1).reshape(ub.size(1), ub.size(0), -1)
+ ub = ub.expand(ub.size(0), ub.size(1), upper_b.size(0)//ub.size(1))
+ ub = ub.reshape(ub.size(0), -1).t()
+ elif lower_b.ndim > lb.ndim:
+ lower_b = lower_b.transpose(0,1).reshape(lower_b.size(1), -1).t()
+ upper_b = upper_b.transpose(0,1).reshape(upper_b.size(1), -1).t()
+ return lb, ub, lower_b, upper_b
+
+
+def inplace_unfold(image, kernel_size, stride=1, padding=0, inserted_zeros=0, output_padding=0):
+ # Image has size (batch_size, channel, height, width).
+ assert image.ndim == 4
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding) # (left, right, top, bottom).
+ if len(padding) == 2: # (height direction, width direction).
+ padding = (padding[1], padding[1], padding[0], padding[0])
+ if isinstance(output_padding, int):
+ output_padding = (output_padding, output_padding, output_padding, output_padding) # (left, right, top, bottom).
+ if len(output_padding) == 2: # (height direction, width direction).
+ output_padding = (output_padding[1], output_padding[1], output_padding[0], output_padding[0])
+ if isinstance(stride, int):
+ stride = (stride, stride) # (height direction, width direction).
+ assert len(kernel_size) == 2 and len(padding) == 4 and len(stride) == 2
+ # Make sure the image is large enough for the kernel.
+ assert image.size(2) + padding[2] + padding[3] >= kernel_size[0] and image.size(3) + padding[0] + padding[1] >= kernel_size[1]
+ if inserted_zeros > 0:
+ # We first need to insert zeros in the image before unfolding.
+ image = insert_zeros(image, inserted_zeros)
+ # padding = (padding[0], padding[1] + 1, padding[2], padding[3] + 1)
+ # Compute the number of patches.
+ # Formulation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold
+ patches_h = int((image.size(2) + padding[2] + padding[3] - (kernel_size[0] - 1) - 1) / stride[0] + 1)
+ patches_w = int((image.size(3) + padding[0] + padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1)
+ # Pad image.
+ if sum(padding) != 0:
+ image = torch.nn.functional.pad(image, padding)
+    # Save its original stride.
+ image_stride = image.stride()
+ matrix_strided = torch.as_strided(image, [
+ # Shape of the output matrix.
+ image.size(0), # Batch size.
+ patches_h, # indices for each patch.
+ patches_w,
+ image.size(1), # Channel.
+ kernel_size[0], # indices for each pixel on a patch.
+ kernel_size[1]], [
+ # Stride of the output matrix.
+ image_stride[0], # Batch size dimension, keep using the old stride.
+ image_stride[2] * stride[0], # Move patch in the height dimension.
+ image_stride[3] * stride[1], # Move patch in the width dimension.
+ image_stride[1], # Move to the next channel.
+ image_stride[2], # Move to the next row.
+ image_stride[3]]) # Move a pixel (on the width direction).
+ # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width)
+ if sum(output_padding) > 0:
+ output_padding = tuple(p if p > 0 else None for p in output_padding)
+ matrix_strided = matrix_strided[:, output_padding[2]:-output_padding[3] if output_padding[3] is not None else None,
+ output_padding[0]:-output_padding[1] if output_padding[1] is not None else None, :, :, :]
+ return matrix_strided
+
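+# Consistency sketch (F is torch.nn.functional imported above): up to a
+# dimension permutation, inplace_unfold matches F.unfold while only creating
+# a strided view of the (padded) image:
+#   x = torch.randn(2, 3, 8, 8)
+#   u = inplace_unfold(x, kernel_size=3, stride=2, padding=1)
+#   # u has shape (batch, patches_h, patches_w, channel, 3, 3)
+#   flat = u.permute(0, 3, 4, 5, 1, 2).reshape(2, 3 * 9, -1)
+#   assert torch.allclose(flat, F.unfold(x, kernel_size=3, stride=2, padding=1))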
+
+def maybe_unfold_patches(d_tensor, last_A, alpha_lookup_idx=None):
+ """
+ Utility function to handle patch mode bound propagation in activation functions.
+ In patches mode, we need to unfold lower and upper slopes (as input "d_tensor"). In matrix mode we simply return.
+ """
+ if d_tensor is None or last_A is None or isinstance(last_A, Tensor):
+ return d_tensor
+
+ # Shape for d_tensor:
+ # sparse: [spec, batch, in_c, in_h, in_w]
+ # non-sparse (partially shared): [out_c, batch, in_c, in_h, in_w]
+ # non-sparse (not shared): [out_c*out_h*out_w, batch, in_c, in_h, in_w]
+ # shared (independent of output spec): [1, batch, in_c, in_h, in_w]
+ # The in_h, in_w dimensions must be unfolded as patches.
+ origin_d_shape = d_tensor.shape
+ if d_tensor.ndim == 6:
+ # Merge the (out_h, out_w) dimensions.
+ d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:])
+ d_shape = d_tensor.size()
+ # Reshape to 4-D tensor to unfold.
+ d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:])
+    # Unfold the slope matrix as patches. Patch shape is [spec * batch, out_h, out_w, in_c, H, W].
+ d_unfolded = inplace_unfold(
+ d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride,
+ padding=last_A.padding, inserted_zeros=last_A.inserted_zeros,
+ output_padding=last_A.output_padding)
+ # Reshape to the original shape of d, e.g., for non-sparse it is (out_c, batch, out_h, out_w, in_c, H, W).
+ d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])
+ if last_A.unstable_idx is not None:
+ # Here we have d for all output neurons, but we only need to select unstable ones.
+ if d_unfolded_r.size(0) == 1:
+            # Shared alpha; sparse alpha should not be used.
+ assert alpha_lookup_idx is None
+ if len(last_A.unstable_idx) == 3:
+ # Broadcast the spec shape, so only need to select the rest dimensions.
+ # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).
+ d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ elif len(last_A.unstable_idx) == 4:
+ # [spec, batch, output_h, output_w, input_c, H, W]
+ # to [output_h, output_w, batch, in_c, H, W]
+ d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]]
+ else:
+ raise NotImplementedError()
+ # output shape: (unstable_size, batch, in_c, H, W).
+ else:
+ # The spec dimension may be sparse and contains unstable neurons for the spec layer only.
+ if alpha_lookup_idx is None:
+ # alpha is spec-dense. Possible because the number of unstable neurons may decrease.
+ if last_A.size(0) == d_unfolded_r.size(0):
+ # Non spec-sparse, partially shared alpha among output channel dimension.
+ # Shape after unfolding is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w).
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ else:
+ # Non spec-sparse, non-shared alpha.
+ # Shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w).
+ # Reshaped to (out_c, out_h, out_w, batch, out_h, out_w, in_c, patch_h, patch_w).
+ d_unfolded_r = d_unfolded_r.view(last_A.shape[0], last_A.shape[2], last_A.shape[3], -1, *d_unfolded_r.shape[2:])
+ # Select on all out_c, out_h, out_w dimensions.
+ d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], last_A.unstable_idx[1],
+ last_A.unstable_idx[2], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ elif alpha_lookup_idx.ndim == 1:
+ # sparse alpha: [spec, batch, in_c, in_h, in_w]
+ # Partially shared alpha on the spec dimension - all output neurons on the same channel use the same alpha.
+ # If alpha_lookup_idx is not None, we need to convert the sparse indices using alpha_lookup_idx.
+ _unstable_idx = alpha_lookup_idx[last_A.unstable_idx[0]]
+ # The selection is only used on the channel dimension.
+ d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ elif alpha_lookup_idx is not None and alpha_lookup_idx.ndim == 3:
+ # sparse alpha: [spec, batch, in_c, in_h, in_w]
+ # We created alpha as full output shape; alpha not shared among channel dimension.
+                # Shape of alpha is (out_c*out_h*out_w, batch, in_c, in_h, in_w); note that the first 3 dimensions
+                # are merged into one to allow simpler selection.
+ _unstable_idx = alpha_lookup_idx[
+ last_A.unstable_idx[0],
+ last_A.unstable_idx[1],
+ last_A.unstable_idx[2]]
+ # d_unfolded_r shape from (out_c, batch, out_h, out_w, in_c, in_h, in_w)
+ # to (out_c * out_h * out_w(sparse), batch, in_c, in_h, in_w)
+ # Note that the dimensions out_h, out_w come from unfolding, not specs in alpha, so they will be selected
+ # directly without translating using the lookup table.
+ d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+ # after selection we return (unstable_size, batch_size, in_c, H, W)
+ return d_unfolded_r
+ else:
+ raise ValueError
+ else:
+            # A is not sparse. Alpha shouldn't be sparse either.
+ assert alpha_lookup_idx is None
+ if last_A.patches.size(0) != d_unfolded_r.size(0) and d_unfolded_r.size(0) != 1:
+ # Non-shared alpha, shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w).
+ # Reshaped to (out_c, out_h*out_w, batch, out_h*out_w, in_c, patch_h, patch_w).
+ d_unfolded_r = d_unfolded_r.reshape(last_A.shape[0], last_A.shape[2] * last_A.shape[3], -1,
+ d_unfolded_r.shape[2] * d_unfolded_r.shape[3], *d_unfolded_r.shape[4:])
+ # Select the "diagonal" elements in the out_h*out_w dimension.
+ # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h*out_w)
+ d_unfolded_r = d_unfolded_r.diagonal(offset=0, dim1=1, dim2=3)
+ # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h, out_w)
+ d_unfolded_r = d_unfolded_r.view(*d_unfolded_r.shape[:-1], last_A.shape[2], last_A.shape[3])
+ # New shape is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w)
+ d_unfolded_r = d_unfolded_r.permute(0, 1, 5, 6, 2, 3, 4)
+
+ # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).
+ # For regular patches, the shape after unfold is (out_c, batch, out_h, out_w, in_c, H, W).
+ if d_unfolded_r.ndim != last_A.patches.ndim:
+        # Handle the case where d is independent of the output neuron (e.g., vanilla CROWN bounds):
+        # it has no out_h, out_w dimensions and out_c = 1 (spec). We add singleton out_h, out_w dimensions.
+ d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4)
+ return d_unfolded_r
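
The unstable-neuron selection above relies on PyTorch advanced indexing over the leading (out_h, out_w) dimensions. A hypothetical toy example of that pattern (dimensions and indices made up for illustration):

```python
import torch

d = torch.randn(4, 4, 2, 3, 7, 7)     # (out_h, out_w, batch, in_c, H, W)
unstable_h = torch.tensor([0, 1, 3])  # hypothetical unstable positions
unstable_w = torch.tensor([2, 0, 1])
# Selects the (h, w) pairs (0, 2), (1, 0) and (3, 1) in one shot.
selected = d[unstable_h, unstable_w]  # (unstable_size, batch, in_c, H, W)
assert selected.shape == (3, 2, 3, 7, 7)
```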
diff --git a/auto_LiRPA/perturbations.py b/auto_LiRPA/perturbations.py
index c3fc386..e198388 100644
--- a/auto_LiRPA/perturbations.py
+++ b/auto_LiRPA/perturbations.py
@@ -1,18 +1,18 @@
-import os
import json
import math
import numpy as np
import torch
-import torch.nn as nn
-from .utils import logger, eyeC, LinearBound, Patches, BoundList, patches_to_matrix, inplace_unfold
-import torch.nn.functional as F
+from .utils import logger, eyeC
+from .patches import Patches, patches_to_matrix
+from .linear_bound import LinearBound
+
class Perturbation:
r"""
Base class for a perturbation specification. Please see examples
- at `auto_LiRPA/perturbations.py`.
+ at `auto_LiRPA/perturbations.py`.
- Examples:
+ Examples:
* `PerturbationLpNorm`: Lp-norm (p>=1) perturbation.
@@ -26,7 +26,7 @@ def __init__(self):
def set_eps(self, eps):
self.eps = eps
-
+
def concretize(self, x, A, sign=-1, aux=None):
r"""
Concretize bounds according to the perturbation specification.
@@ -54,7 +54,7 @@ def init(self, x, aux=None, forward=False):
            aux (object, optional): Auxiliary information.
- forward (bool): It indicates whether forward mode LiRPA is involved.
+ forward (bool): It indicates whether forward mode LiRPA is involved.
Returns:
bound (LinearBound): Initialized bounds.
@@ -69,16 +69,16 @@ def init(self, x, aux=None, forward=False):
"""Perturbation constrained by the L_0 norm (assuming input data is in the range of 0-1)."""
class PerturbationL0Norm(Perturbation):
- def __init__(self, eps, x_L = None, x_U = None, ratio = 1.0):
+ def __init__(self, eps, x_L=None, x_U=None, ratio=1.0):
self.eps = eps
self.x_U = x_U
self.x_L = x_L
self.ratio = ratio
- def concretize(self, x, A, sign = -1, aux = None):
+ def concretize(self, x, A, sign=-1, aux=None):
if A is None:
return None
-
+
eps = math.ceil(self.eps)
x = x.reshape(x.shape[0], -1, 1)
center = A.matmul(x)
@@ -89,7 +89,6 @@ def concretize(self, x, A, sign = -1, aux = None):
neg_mask = A < 0
pos_mask = A >= 0
-
if sign == 1:
A_diff = torch.zeros_like(A)
A_diff[pos_mask] = A[pos_mask] - original[pos_mask]# changes that one weight can contribute to the value
@@ -117,198 +116,193 @@ def init(self, x, aux=None, forward=False):
eye = torch.eye(dim).to(x.device).unsqueeze(0).repeat(batch_size, 1, 1)
lw = eye.reshape(batch_size, dim, *x.shape[1:])
lb = torch.zeros_like(x).to(x.device)
- uw, ub = lw.clone(), lb.clone()
+ uw, ub = lw.clone(), lb.clone()
return LinearBound(lw, lb, uw, ub, x_L, x_U), x, None
def __repr__(self):
return 'PerturbationLpNorm(norm=0, eps={})'.format(self.eps)
+
"""Perturbation constrained by the L_p norm."""
class PerturbationLpNorm(Perturbation):
- def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None, relative=False):
+ def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None):
self.eps = eps
self.norm = norm
self.dual_norm = 1 if (norm == np.inf) else (np.float64(1.0) / (1 - 1.0 / self.norm))
self.x_L = x_L
self.x_U = x_U
- self.relative = relative
+ self.sparse = False
- """Given an variable x and its bound matrix A, compute worst case bound according to Lp norm."""
- def concretize(self, x, A, sign=-1, aux=None, extra_constr=None):
- if A is None:
- return None
- # If A is an identity matrix, we will handle specially.
- def concretize_matrix(A):
- nonlocal x
- if not isinstance(A, eyeC):
- # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size.
- A = A.reshape(A.shape[0], A.shape[1], -1)
+ def get_input_bounds(self, x, A):
+ if self.sparse:
+ if self.x_L_sparse.shape[-1] == A.shape[-1]:
+ x_L, x_U = self.x_L_sparse, self.x_U_sparse
+ else:
+ # In backward mode, A is not sparse
+ x_L, x_U = self.x_L, self.x_U
+ else:
+ x_L = x - self.eps if self.x_L is None else self.x_L
+ x_U = x + self.eps if self.x_U is None else self.x_U
+ return x_L, x_U
+
+ # If A is an identity matrix, we will handle specially.
+ def concretize_matrix(self, x, A, sign, extra_constr):
+ if not isinstance(A, eyeC):
+ # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size.
+ A = A.reshape(A.shape[0], A.shape[1], -1)
+
+ if extra_constr is not None:
+ # For each neuron, we have a beta, so beta size is (Batch, *neuron_size, n_beta) (in A, spec is *neuron_size).
+ # For intermediate layer neurons, A has *neuron_size specifications.
+ beta = extra_constr['beta']
+ beta = beta.view(beta.size(0), -1, beta.size(-1))
+            # coeffs are linear relationships between split neurons and x. They have size (batch, n_beta, *input_size), and are unrelated to neuron_size.
+ beta_coeffs = extra_constr['coeffs']
+ beta_coeffs = beta_coeffs.view(beta_coeffs.size(0), beta_coeffs.size(1), -1)
+            # biases are added for each batch and each spec; size is (batch, n_beta), and unrelated to neuron_size.
+ beta_bias = extra_constr['bias']
+ # Merge beta into extra A and bias. Extra A has size (batch, spec, *input_size). For intermediate neurons, spec is *neuron_size.
+ extra_A = torch.einsum('ijk,ikl->ijl', beta, beta_coeffs)
+ # Merge beta into the bias term. Output has size (batch, spec).
+ extra_bias = torch.einsum('ijk,ik->ij', beta, beta_bias)
+ if self.norm == np.inf:
+ # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps.
+ x_L, x_U = self.get_input_bounds(x, A)
+ x_ub = x_U.reshape(x_U.shape[0], -1, 1)
+ x_lb = x_L.reshape(x_L.shape[0], -1, 1)
+            # Find the upper and lower bound similarly to IBP.
+ center = (x_ub + x_lb) / 2.0
+ diff = (x_ub - x_lb) / 2.0
+ if not isinstance(A, eyeC):
if extra_constr is not None:
- # For each neuron, we have a beta, so beta size is (Batch, *neuron_size, n_beta) (in A, spec is *neuron_size).
- # For intermediate layer neurons, A has *neuron_size specifications.
- beta = extra_constr['beta']
- beta = beta.view(beta.size(0), -1, beta.size(-1))
- # coeffs are linear relationships between split neurons and x. They have size (batch, n_beta, *input_size), and unreated to neuron_size.
- beta_coeffs = extra_constr['coeffs']
- beta_coeffs = beta_coeffs.view(beta_coeffs.size(0), beta_coeffs.size(1), -1)
- # biases are added for each batch each spec, size is (batch, n_beta), and unrelated to neuron_size.
- beta_bias = extra_constr['bias']
- # Merge beta into extra A and bias. Extra A has size (batch, spec, *input_size). For intermediate neurons, spec is *neuron_size.
- extra_A = torch.einsum('ijk,ikl->ijl', beta, beta_coeffs)
- # Merge beta into the bias term. Output has size (batch, spec).
- extra_bias = torch.einsum('ijk,ik->ij', beta, beta_bias)
-
- if self.norm == np.inf:
- # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps.
- x_L = x - self.eps if self.x_L is None else self.x_L
- x_U = x + self.eps if self.x_U is None else self.x_U
- x_ub = x_U.reshape(x_U.shape[0], -1, 1)
- x_lb = x_L.reshape(x_L.shape[0], -1, 1)
- # Find the uppwer and lower bound similarly to IBP.
- center = (x_ub + x_lb) / 2.0
- diff = (x_ub - x_lb) / 2.0
- if not isinstance(A, eyeC):
- if extra_constr is not None:
- # Extra linear and bias terms from constraints.
- print(
- f'A extra: {(sign * extra_A).abs().sum().item()}, b extra: {(sign * extra_bias).abs().sum().item()}')
- A = A - sign * extra_A
- bound = A.matmul(center) - sign * extra_bias.unsqueeze(-1) + sign * A.abs().matmul(diff)
- else:
- bound = A.matmul(center) + sign * A.abs().matmul(diff)
+ # Extra linear and bias terms from constraints.
+ print(
+ f'A extra: {(sign * extra_A).abs().sum().item()}, '
+ f'b extra: {(sign * extra_bias).abs().sum().item()}')
+ A = A - sign * extra_A
+ bound = A.matmul(center) - sign * extra_bias.unsqueeze(-1) + sign * A.abs().matmul(diff)
else:
- assert extra_constr is None
- # A is an identity matrix. No need to do this matmul.
- bound = center + sign * diff
+ bound = A.matmul(center) + sign * A.abs().matmul(diff)
else:
assert extra_constr is None
- x = x.reshape(x.shape[0], -1, 1)
- if not isinstance(A, eyeC):
- # Find the upper and lower bounds via dual norm.
- deviation = A.norm(self.dual_norm, -1) * self.eps
- bound = A.matmul(x) + sign * deviation.unsqueeze(-1)
- else:
- # A is an identity matrix. Its norm is all 1.
- bound = x + sign * self.eps
- bound = bound.squeeze(-1)
- return bound
+ # A is an identity matrix. No need to do this matmul.
+ bound = center + sign * diff
+ else:
+ assert extra_constr is None
+ x = x.reshape(x.shape[0], -1, 1)
+ if not isinstance(A, eyeC):
+ # Find the upper and lower bounds via dual norm.
+ deviation = A.norm(self.dual_norm, -1) * self.eps
+ bound = A.matmul(x) + sign * deviation.unsqueeze(-1)
+ else:
+ # A is an identity matrix. Its norm is all 1.
+ bound = x + sign * self.eps
+ bound = bound.squeeze(-1)
+ return bound
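
For the Linf case, `concretize_matrix` applies the standard interval rule bound = A·center ± |A|·diff. A self-contained sketch with assumed shapes (not library code):

```python
import torch

batch, spec, dim = 2, 3, 5
A = torch.randn(batch, spec, dim)  # bound matrix
x = torch.rand(batch, dim)         # nominal input
eps = 0.1
center = x.unsqueeze(-1)           # (batch, dim, 1)
diff = torch.full_like(center, eps)
lower = (A.matmul(center) - A.abs().matmul(diff)).squeeze(-1)
upper = (A.matmul(center) + A.abs().matmul(diff)).squeeze(-1)
assert (lower <= upper).all()
```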
- def concretize_patches(A):
- nonlocal x
- if self.norm == np.inf:
- # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps.
- x_L = x - self.eps if self.x_L is None else self.x_L
- x_U = x + self.eps if self.x_U is None else self.x_U
-
- # Here we should not reshape
- # Find the uppwer and lower bound similarly to IBP.
- center = (x_U + x_L) / 2.0
- diff = (x_U - x_L) / 2.0
- if not A.identity == 1:
- # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W] or [unstable_size, batch_size, in_c, H, W]. Here out_c is the spec dimension.
- patches = A.patches
-
- # unfold the input as [batch_size, out_h, out_w, in_c, H, W]
- unfold_input = inplace_unfold(center, kernel_size=A.patches.shape[-2:], padding = A.padding, stride = A.stride)
- if A.unstable_idx is not None:
- # We need to add a out_c dimension and select from it.
- unfold_input = unfold_input.unsqueeze(0).expand(A.output_shape[1], -1, -1, -1, -1, -1, -1)
- # When A is sparse, the shape is [unstable_size, batch_size, in_c, H, W]. Here unfold_input will match this shape.
- unfold_input = unfold_input[A.unstable_idx[0], :, A.unstable_idx[1], A.unstable_idx[2]]
- # size of bound: [batch_size, unstable_size].
- bound = torch.einsum('sbchw,sbchw->bs', unfold_input, patches)
- else:
- # size of bound: [batch_size, out_c, out_h, out_w].
- bound = torch.einsum('bijchw,sbijchw->bsij', unfold_input, patches)
-
- # unfold the diff as [batch_size, out_h, out_w, in_c, H, W]
- unfold_diff = inplace_unfold(diff, kernel_size=A.patches.shape[-2:], padding = A.padding, stride = A.stride)
- if A.unstable_idx is not None:
- # We need to add a out_c dimension and select from it.
- unfold_diff = unfold_diff.unsqueeze(0).expand(A.output_shape[1], -1, -1, -1, -1, -1, -1)
- # When A is sparse, the shape is [unstable_size, batch_size, in_c, H, W]
- unfold_diff = unfold_diff[A.unstable_idx[0], :, A.unstable_idx[1], A.unstable_idx[2]]
- # size of diff: [batch_size, unstable_size].
- bound_diff = torch.einsum('sbchw,sbchw->bs', unfold_diff, patches.abs())
- else:
- # size of diff: [batch_size, out_c, out_h, out_w].
- bound_diff = torch.einsum('bijchw,sbijchw->bsij', unfold_diff, patches.abs())
-
- if sign == 1:
- bound += bound_diff
- elif sign == -1:
- bound -= bound_diff
- else:
- raise ValueError("Unsupported Sign")
-
- # The extra bias term from beta term.
- if extra_constr is not None:
- bound += extra_constr
- else:
- assert extra_constr is None
- # A is an identity matrix. No need to do this matmul.
- bound = center + sign * diff
- return bound
- else: # Lp norm
- # x_L = x - self.eps if self.x_L is None else self.x_L
- # x_U = x + self.eps if self.x_U is None else self.x_U
-
- input_shape = x.shape
- if not A.identity:
- # Find the upper and lower bounds via dual norm.
- # matrix has shape (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) or (batch_size, unstable_size, input_c, input_h, input_w)
- matrix = patches_to_matrix(A.patches, input_shape, A.stride, A.padding, A.output_shape, A.unstable_idx)
- # Note that we should avoid reshape the matrix. Due to padding, matrix cannot be reshaped without copying.
- deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps
- # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size).
- bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation
- if A.unstable_idx is None:
- # Reshape to (batch, out_c, out_h, out_w).
- bound = bound.view(matrix.size(0), A.patches.size(0), A.patches.size(2), A.patches.size(3))
+ def concretize_patches(self, x, A, sign, extra_constr):
+ if self.norm == np.inf:
+ x_L, x_U = self.get_input_bounds(x, A)
+
+            # Here we should not reshape x.
+            # Find the upper and lower bound similarly to IBP.
+ center = (x_U + x_L) / 2.0
+ diff = (x_U - x_L) / 2.0
+
+ if not A.identity == 1:
+ bound = A.matmul(center)
+ bound_diff = A.matmul(diff, patch_abs=True)
+
+ if sign == 1:
+ bound += bound_diff
+ elif sign == -1:
+ bound -= bound_diff
else:
- # A is an identity matrix. Its norm is all 1.
- bound = x + sign * self.eps
- return bound
+ raise ValueError("Unsupported Sign")
+
+ # The extra bias term from beta term.
+ if extra_constr is not None:
+ bound += extra_constr
+ else:
+ assert extra_constr is None
+ # A is an identity matrix. No need to do this matmul.
+ bound = center + sign * diff
+ return bound
+ else: # Lp norm
+ input_shape = x.shape
+ if not A.identity:
+ # Find the upper and lower bounds via dual norm.
+ # matrix has shape (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) or (batch_size, unstable_size, input_c, input_h, input_w)
+ matrix = patches_to_matrix(A.patches, input_shape, A.stride, A.padding, A.output_shape, A.unstable_idx)
+ # Note that we should avoid reshape the matrix. Due to padding, matrix cannot be reshaped without copying.
+ deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps
+ # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size).
+ bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation
+ if A.unstable_idx is None:
+ # Reshape to (batch, out_c, out_h, out_w).
+ bound = bound.view(matrix.size(0), A.patches.size(0), A.patches.size(2), A.patches.size(3))
+ else:
+ # A is an identity matrix. Its norm is all 1.
+ bound = x + sign * self.eps
+ return bound
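
The Lp branch is Hölder's inequality: for ||delta||_p <= eps, the worst case of A·delta is eps * ||A||_q with 1/p + 1/q = 1, which is what the `dual_norm` deviation computes. A toy sketch with assumed shapes:

```python
import torch

p, eps = 2.0, 0.1
q = 1.0 / (1.0 - 1.0 / p)  # dual norm exponent; q = 2 when p = 2
A = torch.randn(2, 3, 5)   # (batch, spec, dim)
x = torch.rand(2, 5)
deviation = A.norm(q, dim=-1) * eps
center = torch.einsum('bsd,bd->bs', A, x)
lower, upper = center - deviation, center + deviation
```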
+ """Given an variable x and its bound matrix A, compute worst case bound according to Lp norm."""
+ def concretize(self, x, A, sign=-1, aux=None, extra_constr=None):
+ if A is None:
+ return None
if isinstance(A, eyeC) or isinstance(A, torch.Tensor):
- return concretize_matrix(A)
+ return self.concretize_matrix(x, A, sign, extra_constr)
elif isinstance(A, Patches):
- return concretize_patches(A)
- elif isinstance(A, BoundList):
- for b in A.bound_list:
- if isinstance(b, eyeC) or isinstance(b, torch.Tensor):
- pass
+ return self.concretize_patches(x, A, sign, extra_constr)
else:
raise NotImplementedError()
+ """ Sparse Linf perturbation where only a few dimensions are actually perturbed"""
+ def init_sparse_linf(self, x, x_L, x_U):
+ self.sparse = True
+ batch_size = x_L.shape[0]
+ perturbed = (x_U > x_L).int()
+ logger.debug(f'Perturbed: {perturbed.sum()}')
+ lb = ub = x_L * (1 - perturbed) # x_L=x_U holds when perturbed=0
+ perturbed = perturbed.view(batch_size, -1)
+ index = torch.cumsum(perturbed, dim=-1)
+ dim = max(perturbed.view(batch_size, -1).sum(dim=-1).max(), 1)
+ self.x_L_sparse = torch.zeros(batch_size, dim + 1).to(x_L)
+ self.x_L_sparse.scatter_(dim=-1, index=index, src=(x_L - lb).view(batch_size, -1), reduce='add')
+ self.x_U_sparse = torch.zeros(batch_size, dim + 1).to(x_U)
+ self.x_U_sparse.scatter_(dim=-1, index=index, src=(x_U - ub).view(batch_size, -1), reduce='add')
+ self.x_L_sparse, self.x_U_sparse = self.x_L_sparse[:, 1:], self.x_U_sparse[:, 1:]
+ lw = torch.zeros(batch_size, dim + 1, perturbed.shape[-1], device=x.device)
+ perturbed = perturbed.to(torch.get_default_dtype())
+ lw.scatter_(dim=1, index=index.unsqueeze(1), src=perturbed.unsqueeze(1))
+ lw = uw = lw[:, 1:, :].view(batch_size, dim, *x.shape[1:])
+ print(f'Using Linf sparse perturbation. Perturbed dimensions: {dim}.')
+ print(f'Avg perturbation: {(self.x_U_sparse - self.x_L_sparse).mean()}')
+ return LinearBound(
+ lw, lb, uw, ub, x_L, x_U), x, None
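
The compaction above uses a cumulative-sum index plus a scatter-add; slot 0 absorbs all unperturbed dimensions and is dropped afterwards. A hypothetical toy illustration (using `scatter_add_` in place of `scatter_` with `reduce='add'`):

```python
import torch

x_L = torch.tensor([[0., 1., 2., 3.]])
x_U = torch.tensor([[0., 1.5, 2., 3.5]])  # only dims 1 and 3 are perturbed
perturbed = (x_U > x_L).long()            # [[0, 1, 0, 1]]
index = torch.cumsum(perturbed, dim=-1)   # [[0, 1, 1, 2]]
dim = int(perturbed.sum(dim=-1).max())    # 2 perturbed dimensions
sparse_L = torch.zeros(1, dim + 1)
# Unperturbed entries are masked to zero, so they only ever add 0.
sparse_L.scatter_add_(-1, index, x_L * perturbed)
print(sparse_L[:, 1:])                    # tensor([[1., 3.]])
```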
+
def init(self, x, aux=None, forward=False):
+ self.sparse = False
if self.norm == np.inf:
x_L = x - self.eps if self.x_L is None else self.x_L
x_U = x + self.eps if self.x_U is None else self.x_U
else:
# For other norms, we pass in the BoundedTensor objects directly.
- x_L = x
- x_U = x
- if self.relative:
- nominal = x
- lower_offset = torch.max(x_L - x - 1e-8, torch.ones_like(x_L) * (-self.eps))
- upper_offset = torch.min(x_U - x + 1e-8, torch.ones_like(x_U) * (self.eps))
- else:
- nominal = lower_offset = upper_offset = None
+ x_L = x_U = x
if not forward:
return LinearBound(
- None, None, None, None, x_L, x_U,
- nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset), x, None
+ None, None, None, None, x_L, x_U), x, None
+ if self.norm == np.inf and x_L.numel() > 1 and (x_L == x_U).sum() > 0.5 * x_L.numel():
+ return self.init_sparse_linf(x, x_L, x_U)
+
batch_size = x.shape[0]
dim = x.reshape(batch_size, -1).shape[-1]
+ lb = ub = torch.zeros_like(x)
eye = torch.eye(dim).to(x).expand(batch_size, dim, dim)
- lw = eye.reshape(batch_size, dim, *x.shape[1:])
- lb = torch.zeros_like(x).to(x.device)
- uw, ub = lw.clone(), lb.clone()
+ lw = uw = eye.reshape(batch_size, dim, *x.shape[1:])
return LinearBound(
- lw, lb, uw, ub, x_L, x_U,
- nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset), x, None
+ lw, lb, uw, ub, x_L, x_U), x, None
def __repr__(self):
if self.norm == np.inf:
@@ -354,7 +348,7 @@ def concretize(self, x, A, sign, aux):
num_pos = int(np.max(np.sum(can_be_replaced, axis=-1)))
update_A = A.shape[-1] > num_pos * dim_word
if update_A:
- bias = torch.bmm(A, (x * (1 - mask_rep).unsqueeze(-1)).reshape(batch_size, -1, 1)).squeeze(-1)
+ bias = torch.bmm(A, (x * (1 - mask_rep).unsqueeze(-1)).reshape(batch_size, -1, 1)).squeeze(-1)
else:
bias = 0.
A = A.reshape(batch_size, dim_out, -1, dim_word)
@@ -386,7 +380,7 @@ def concretize(self, x, A, sign, aux):
mask = torch.cat(mask_new).reshape(batch_size, num_pos, max_num_cand)
length = num_pos
- A = A.reshape(batch_size, A.shape[1], length, -1).transpose(1, 2)
+ A = A.reshape(batch_size, A.shape[1], length, -1).transpose(1, 2)
x = x.reshape(batch_size, length, -1, 1)
if sign == 1:
@@ -396,8 +390,8 @@ def concretize(self, x, A, sign, aux):
init_tensor = torch.ones(batch_size, dim_out).to(x.device) * init
dp = [[init_tensor] * (self.budget + 1) for i in range(0, length + 1)]
- dp[0][0] = torch.zeros(batch_size, dim_out).to(x.device)
-
+ dp[0][0] = torch.zeros(batch_size, dim_out).to(x.device)
+
A = A.reshape(batch_size * length, A.shape[2], A.shape[3])
Ax = torch.bmm(
A,
@@ -418,10 +412,10 @@ def concretize(self, x, A, sign, aux):
dp[i][0] = dp[i - 1][0] + Ax[:, i - 1]
for j in range(1, self.budget + 1):
dp[i][j] = cmp(
- dp[i - 1][j] + Ax[:, i - 1],
+ dp[i - 1][j] + Ax[:, i - 1],
dp[i - 1][j - 1] + Ax_rep_bound[:, i - 1]
)
- dp = torch.cat(dp[length], dim=0).reshape(self.budget + 1, batch_size, dim_out)
+ dp = torch.cat(dp[length], dim=0).reshape(self.budget + 1, batch_size, dim_out)
return cmp(dp, dim=0).values + bias
@@ -432,7 +426,7 @@ def init(self, x, aux=None, forward=False):
batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2]
max_pos = 1
- can_be_replaced = np.zeros((batch_size, length), dtype=bool)
+ can_be_replaced = np.zeros((batch_size, length), dtype=np.bool)
self._build_substitution(batch)
@@ -457,8 +451,8 @@ def init(self, x, aux=None, forward=False):
if forward:
eye = torch.eye(dim_word).to(x.device)
lw = torch.zeros(batch_size, dim, length, dim_word).to(x.device)
- lb = torch.zeros_like(x).to(x.device)
- x_new = []
+ lb = torch.zeros_like(x).to(x.device)
+ x_new = []
word_embeddings = self.model.word_embeddings.weight
vocab = self.model.vocab
x_rep = [[[] for i in range(length)] for t in range(batch_size)]
@@ -467,8 +461,8 @@ def init(self, x, aux=None, forward=False):
candidates = batch[t]['candidates']
# for transformers
if tokens[t][0] == '[CLS]':
- candidates = [[]] + candidates + [[]]
- cnt = 0
+ candidates = [[]] + candidates + [[]]
+ cnt = 0
for i in range(length):
if can_be_replaced[t][i]:
word_embed = word_embeddings[vocab[tokens[t][i]]]
@@ -485,13 +479,13 @@ def init(self, x, aux=None, forward=False):
cnt += 1
else:
if forward:
- lb[t, i, :] = x[t, i, :]
+ lb[t, i, :] = x[t, i, :]
if forward:
uw, ub = lw, lb
else:
lw = lb = uw = ub = None
zeros = torch.zeros(dim_word, device=x.device)
-
+
x_rep_, mask = [], []
for t in range(batch_size):
for i in range(length):
@@ -501,7 +495,7 @@ def init(self, x, aux=None, forward=False):
mask = torch.tensor(mask, dtype=torch.float32, device=x.device)\
.reshape(batch_size, length, max_num_cand)
x_rep_ = x_rep_ * self.eps + x.unsqueeze(2) * (1 - self.eps)
-
+
inf = 1e20
lower = torch.min(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * inf, dim=2).values
upper = torch.max(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * (-inf), dim=2).values
diff --git a/auto_LiRPA/solver_module.py b/auto_LiRPA/solver_module.py
new file mode 100644
index 0000000..9720ce5
--- /dev/null
+++ b/auto_LiRPA/solver_module.py
@@ -0,0 +1,132 @@
+import multiprocessing
+import multiprocessing.pool
+import sys
+import os
+
+import torch
+from .bound_ops import *
+
+
+def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None,
+ final_node_name=None, model_type="mip", solver_pkg="gurobi"):
+ r"""build lp/mip solvers in general graph.
+
+ Args:
+        x: inputs, a list of BoundedTensor. If set to None, we reuse existing bounds that
+ were previously computed in compute_bounds().
+ C (Tensor): The specification matrix that can map the output of the model with an
+            additional linear layer. This is usually used for mapping the logits output of the
+ model to classification margins.
+ intermediate_layer_bounds: if specified, will replace existing intermediate layer bounds.
+            Otherwise we reuse existing intermediate bounds.
+
+        final_node_name (String): the name of the target layer to optimize
+
+        solver_pkg (String): the backbone of the solver; default gurobi, also supports scipy
+
+ Returns:
+ output vars (list): a list of final nodes to optimize
+ """
+ # self.root_name: list of root node name
+ # self.final_name: list of output node name
+ # self.final_node: output module
+ # .input: a list of input modules of this layer module
+ # .solver_vars: a list of gurobi vars of every layer module
+ # list with conv shape if conv layers, otherwise flattened
+ # if last layer we need to be careful with:
+ # C: specification matrix
+ # .is_input_perturbed(1)
+
+ if x is not None:
+ assert intermediate_layer_bounds is not None
+ # Set the model to use new intermediate layer bounds, ignore the original ones.
+ self._set_input(x, intermediate_layer_bounds=intermediate_layer_bounds)
+
+ root = [self[name] for name in self.root_name]
+
+ # create interval ranges for input and other weight parameters
+ for i in range(len(root)):
+ value = root[i].forward()
+ # if isinstance(root[i], BoundInput) and not isinstance(root[i], BoundParams):
+ if type(root[i]) is BoundInput:
+ # create input vars for gurobi self.model
+ inp_gurobi_vars = self._build_solver_input(root[i])
+ else:
+ # regular weights
+ root[i].solver_vars = value
+
+ final = self.final_node() if final_node_name is None else self[final_node_name]
+
+ # backward propagate every layer including last layer
+ self._build_solver_general(node=final, C=C, model_type=model_type, solver_pkg=solver_pkg)
+
+ # a list of output solver vars
+ return final.solver_vars
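
A hypothetical usage sketch (requires `gurobipy`; the network `net` and its input shape are assumptions): compute intermediate bounds first, then build the MIP reusing them.

```python
import torch
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

lirpa_model = BoundedModule(net, torch.empty(1, 28 * 28))  # `net` is assumed
x = BoundedTensor(torch.rand(1, 28 * 28), PerturbationLpNorm(eps=0.01))
lirpa_model.compute_bounds(x=(x,), method='CROWN')  # populates intermediate bounds
out_vars = lirpa_model.build_solver_module(model_type='mip')  # x=None reuses them
print(len(out_vars))  # one Gurobi variable per output neuron
```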
+
+
+def _build_solver_general(self, node, C=None, model_type="mip", solver_pkg="gurobi"):
+ if not hasattr(node, 'solver_vars'):
+ for n in node.inputs:
+ self._build_solver_general(n, C=C, model_type=model_type, solver_pkg=solver_pkg)
+ inp = [n_pre.solver_vars for n_pre in node.inputs]
+ # print(node, node.inputs)
+ if C is not None and isinstance(node, BoundLinear) and\
+ not node.is_input_perturbed(1) and self.final_name == node.name:
+ # when node is the last layer
+ # merge the last BoundLinear node with the specification,
+ # available when weights of this layer are not perturbed
+ solver_vars = node.build_solver(*inp, model=self.model, C=C,
+ model_type=model_type, solver_pkg=solver_pkg)
+ else:
+ solver_vars = node.build_solver(*inp, model=self.model, C=None,
+ model_type=model_type, solver_pkg=solver_pkg)
+ # just return output node gurobi vars
+ return solver_vars
+
+
+def _build_solver_input(self, node):
+ ## Do the input layer, which is a special case
+ assert isinstance(node, BoundInput)
+ assert node.perturbation is not None
+ assert node.perturbation.norm == float("inf")
+ inp_gurobi_vars = []
+ # zero var will be shared within the solver model
+ zero_var = self.model.addVar(lb=0, ub=0, obj=0, vtype=grb.GRB.CONTINUOUS, name='zero')
+ x_L = node.value - node.perturbation.eps if node.perturbation.x_L is None else node.perturbation.x_L
+ x_U = node.value + node.perturbation.eps if node.perturbation.x_U is None else node.perturbation.x_U
+ x_L = x_L.squeeze(0)
+ x_U = x_U.squeeze(0)
+ # x_L, x_U = node.lower.squeeze(0), node.upper.squeeze(0)
+
+ if x_L.ndim == 1:
+ # This is a linear input.
+ for dim, (lb, ub) in enumerate(zip(x_L, x_U)):
+ v = self.model.addVar(lb=lb, ub=ub, obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'inp_{dim}')
+ inp_gurobi_vars.append(v)
+ else:
+ assert x_L.ndim == 3, f"x_L ndim {x_L.ndim}"
+ dim = 0
+ for chan in range(x_L.shape[0]):
+ chan_vars = []
+ for row in range(x_L.shape[1]):
+ row_vars = []
+ for col in range(x_L.shape[2]):
+ lb = x_L[chan, row, col]
+ ub = x_U[chan, row, col]
+ v = self.model.addVar(lb=lb, ub=ub, obj=0,
+ vtype=grb.GRB.CONTINUOUS,
+ name=f'inp_{dim}')
+ # name=f'inp_[{chan},{row},{col}]')
+ row_vars.append(v)
+ dim += 1
+ chan_vars.append(row_vars)
+ inp_gurobi_vars.append(chan_vars)
+
+ node.solver_vars = inp_gurobi_vars
+ # save the gurobi input variables so that we can later extract primal values in input space easily
+ self.input_vars = inp_gurobi_vars
+ self.model.update()
+ return inp_gurobi_vars
+
diff --git a/auto_LiRPA/utils.py b/auto_LiRPA/utils.py
index a6c4280..e50ea0c 100644
--- a/auto_LiRPA/utils.py
+++ b/auto_LiRPA/utils.py
@@ -3,15 +3,18 @@
import time
import torch
import torch.nn as nn
+import torch.nn.functional as F
import os
import sys
import appdirs
-from collections import defaultdict, Sequence, namedtuple
+from collections import defaultdict, namedtuple
+from collections.abc import Sequence
from functools import reduce
import operator
import math
-import torch.nn.functional as F
import warnings
+from typing import Tuple
+from .patches import Patches, insert_zeros
logging.basicConfig(
format='%(levelname)-8s %(asctime)-12s %(message)s',
@@ -19,7 +22,7 @@
stream=sys.stdout
)
logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(logging.DEBUG if os.environ.get('AUTOLIRPA_DEBUG', 0) else logging.INFO)
warnings.simplefilter("once")
@@ -35,6 +38,24 @@
reduction_max = lambda x: x.max(1, keepdim=True).values
reduction_min = lambda x: x.min(1, keepdim=True).values
+MIN_HALF_FP = 5e-8  # close to 2**-24, the smallest positive value representable in float16
+
+
+def reduction_str2func(reduction_func):
+ if type(reduction_func) == str:
+ if reduction_func == 'min':
+ return reduction_min
+ elif reduction_func == 'max':
+ return reduction_max
+ elif reduction_func == 'sum':
+ return reduction_sum
+ elif reduction_func == 'mean':
+ return reduction_mean
+ else:
+ raise NotImplementedError(f'Unknown reduction_func {reduction_func}')
+ else:
+ return reduction_func
+
def stop_criterion_sum(threshold=0):
return lambda x: (x.sum(1, keepdim=True) > threshold)
@@ -47,34 +68,22 @@ def stop_criterion_min(threshold=0):
def stop_criterion_max(threshold=0):
return lambda x: (x.max(1, keepdim=True).values > threshold)
-# Create a namedtuple with defaults
-def namedtuple_with_defaults(name, attr, defaults):
- assert sys.version_info.major == 3
- if sys.version_info.major >= 7:
- return namedtuple(name, attr, defaults=defaults)
- else:
- # The defaults argument is not available in Python < 3.7
- t = namedtuple(name, attr)
- t.__new__.__defaults__ = defaults
- return t
-
-# A special class which denotes a convoluntional operator as a group of patches
-# the shape of Patches.patches is [batch_size, num_of_patches, out_channel, in_channel, M, M]
-# M is the size of a single patch
-# Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N)
-# num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2
-# Here we only consider kernels with the same H and W
-Patches = namedtuple_with_defaults('Patches', ('patches', 'stride', 'padding', 'shape', 'identity', 'unstable_idx', 'output_shape'), defaults=(None, 1, 0, None, 0, None, None))
-BoundList = namedtuple_with_defaults('BoundList', ('bound_list'), defaults=([],))
-# Linear bounds with coefficients. Used for forward bound propagation.
-LinearBound = namedtuple_with_defaults('LinearBound', ('lw', 'lb', 'uw', 'ub', 'lower', 'upper', 'from_input', 'nominal', 'lower_offset', 'upper_offset'), defaults=(None,) * 10)
-
-# for debugging
-if False:
- file_handler = logging.FileHandler('debug.log')
- file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))
- logger.addHandler(file_handler)
- logger.setLevel(logging.DEBUG)
+def stop_criterion_batch(threshold=0):
+    # May broadcast unexpectedly; pay attention to the shape of threshold.
+    # x shape: (batch, number_bounds); threshold shape: (batch, number_bounds).
+    return lambda x: (x > threshold)
+
+def stop_criterion_batch_any(threshold=0):
+    # May broadcast unexpectedly; pay attention to the shape of threshold.
+    # x shape: (batch, number_bounds); threshold shape: (batch, number_bounds).
+    return lambda x: (x > threshold).any(dim=1)
+
+def stop_criterion_batch_topk(threshold=0, k=1314):
+    # x shape: (batch, number_bounds); threshold shape: (batch, number_bounds).
+    return lambda x: (torch.kthvalue(x, k, dim=-1, keepdim=True).values > threshold).any(dim=1)
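
A toy check (a sketch, assuming the functions above are in scope) contrasting the two batched criteria; with `k=1` the top-k variant requires the minimum bound of each batch element to clear the threshold:

```python
import torch

bounds = torch.tensor([[0.2, -0.1],
                       [0.5,  0.3]])
print(stop_criterion_batch_any(0.)(bounds))        # tensor([True, True])
print(stop_criterion_batch_topk(0., k=1)(bounds))  # tensor([False, True])
```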
user_data_dir = appdirs.user_data_dir('auto_LiRPA')
if not os.path.exists(user_data_dir):
@@ -171,7 +180,7 @@ def scale_gradients(optimizer, gradient_accumulation_steps, grad_clip=None):
if param.grad is not None:
param.grad.data /= gradient_accumulation_steps
if grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(parameters, grad_clip)
+ return torch.nn.utils.clip_grad_norm_(parameters, grad_clip)
def recursive_map (seq, func):
for item in seq:
@@ -224,55 +233,6 @@ def batched_index_select(input, dim, index):
return torch.gather(input, dim, index)
-"""Converting a Patches piece into a full dense matrix."""
-def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, unstable_idx=None):
- if type(padding) == int:
- padding = (padding, padding, padding, padding)
- if output_shape is None:
- assert pieces.ndim == 7
- # Non-sparse pieces, with shape (out_c, batch, out_h, out_w, c, h, w).
- output_channel, batch_size, output_x, output_y = pieces.shape[:4]
- else:
- batch_size, output_channel, output_x, output_y = output_shape
- input_channel, kernel_x, kernel_y = pieces.shape[-3:]
- input_x, input_y = input_shape[-2:]
-
- if unstable_idx is None:
- # Fix all patches in a full A matrix.
- A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device)
- # Save its orignal stride.
- orig_stride = A_matrix.stride()
- # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.
- # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
- matrix_strided = torch.as_strided(A_matrix, [batch_size, output_channel, output_x, output_y, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], orig_stride[2], orig_stride[3], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[4], input_y + padding[0] + padding[1], 1])
- # Now we need to fill the conv kernel parameters into the last three dimensions of matrix_strided.
- first_indices = torch.arange(output_x * output_y, device=pieces.device)
- second_indices = torch.div(first_indices, output_y, rounding_mode="trunc")
- third_indices = torch.fmod(first_indices, output_y)
- # pieces have shape (out_c, batch, out_h, out_w, c, h, w).
- pieces = pieces.transpose(0, 1) # pieces has the out_c dimension at the front, need to move it to the second.
- matrix_strided[:,:,second_indices,third_indices,second_indices,third_indices,:,:,:] = pieces.reshape(*pieces.shape[:2], -1, *pieces.shape[4:])
- A_matrix = A_matrix.view(batch_size, output_channel * output_x * output_y, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1])
- else:
- # Fill only a selection of patches.
- # Create only a partial A matrix.
- unstable_size = unstable_idx[0].numel()
- A_matrix = torch.zeros(batch_size, unstable_size, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device)
- # Save its orignal stride.
- orig_stride = A_matrix.stride()
- # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.
- # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
- matrix_strided = torch.as_strided(A_matrix, [batch_size, unstable_size, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[2], input_y + padding[0] + padding[1], 1])
- # pieces have shape (unstable_size, batch, c, h, w).
- first_indices = torch.arange(unstable_size, device=pieces.device)
- matrix_strided[:,first_indices,unstable_idx[1],unstable_idx[2],:,:,:] = pieces.transpose(0,1)
- A_matrix = A_matrix.view(batch_size, unstable_size, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1])
-
- A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]]
-
- return A_matrix
-
-
def check_padding(x, padding):
if isinstance(padding, int):
return x, (padding, padding)
@@ -283,47 +243,6 @@ def check_padding(x, padding):
return F.pad(x, padding), (0, 0)
-def inplace_unfold(image, kernel_size, stride=1, padding=0):
- # Image has size (batch_size, channel, height, width).
- assert image.ndim == 4
- if isinstance(kernel_size, int):
- kernel_size = (kernel_size, kernel_size)
- if isinstance(padding, int):
- padding = (padding, padding, padding, padding) # (left, right, top, bottom).
- if len(padding) == 2: # (height direction, width direction).
- padding = (padding[1], padding[1], padding[0], padding[0])
- if isinstance(stride, int):
- stride = (stride, stride) # (height direction, width direction).
- assert len(kernel_size) == 2 and len(padding) == 4 and len(stride) == 2
- # Make sure the image is large enough for the kernel.
- assert image.size(2) + padding[2] + padding[3] >= kernel_size[0] and image.size(3) + padding[0] + padding[1] >= kernel_size[1]
- # Compute the number of patches.
- # Formulation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold
- patches_h = int((image.size(2) + padding[2] + padding[3] - (kernel_size[0] - 1) - 1) / stride[0] + 1)
- patches_w = int((image.size(3) + padding[0] + padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1)
- # Pad image.
- if sum(padding) != 0:
- image = torch.nn.functional.pad(image, padding)
- # Save its orignal stride.
- image_stride = image.stride()
- matrix_strided = torch.as_strided(image, [
- # Shape of the output matrix.
- image.size(0), # Batch size.
- patches_h, # indices for each patch.
- patches_w,
- image.size(1), # Channel.
- kernel_size[0], # indices for each pixel on a patch.
- kernel_size[1]], [
- # Stride of the output matrix.
- image_stride[0], # Batch size dimension, keep using the old stride.
- image_stride[2] * stride[0], # Move patch in the height dimension.
- image_stride[3] * stride[1], # Move patch in the width dimension.
- image_stride[1], # Move to the next channel.
- image_stride[2], # Move to the next row.
- image_stride[3]]) # Move a pixel (on the width direction).
- # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width)
- return matrix_strided
-
def get_spec_matrix(X, y, num_classes):
with torch.no_grad():
c = (torch.eye(num_classes).type_as(X)[y].unsqueeze(1)
@@ -331,3 +250,40 @@ def get_spec_matrix(X, y, num_classes):
I = (~(y.unsqueeze(1) == torch.arange(num_classes).type_as(y).unsqueeze(0)))
c = (c[I].view(X.size(0), num_classes - 1, num_classes))
return c
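
For reference, a quick sketch of what `get_spec_matrix` produces: one specification row per wrong class, encoding the margin logit_y - logit_j:

```python
import torch

c = get_spec_matrix(torch.zeros(2, 3), torch.tensor([0, 2]), num_classes=3)
print(c.shape)  # torch.Size([2, 2, 3])
print(c[0])     # rows e_0 - e_1 and e_0 - e_2 for the example with label 0
```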
+
+def unravel_index(
+ indices: torch.LongTensor,
+ shape: Tuple[int, ...],
+) -> torch.LongTensor:
+ r"""Converts flat indices into unraveled coordinates in a target shape.
+
+ Args:
+ indices: A tensor of (flat) indices, (*, N).
+ shape: The targeted shape, (D,).
+
+ Returns:
+        The unraveled coordinates: a list of D tensors, each with the same shape as `indices`.
+
+ Code borrowed from:
+ https://github.com/pytorch/pytorch/issues/35674
+ """
+
+ coord = []
+
+ for dim in reversed(shape):
+ coord.append(indices % dim)
+ indices = torch.div(indices, dim, rounding_mode='trunc')
+
+ return list(reversed(coord))
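
A quick usage sketch:

```python
import torch

coords = unravel_index(torch.tensor([0, 5, 11]), (3, 4))
print(coords[0])  # tensor([0, 1, 2]), row of each flat index in a 3x4 grid
print(coords[1])  # tensor([0, 1, 3]), column of each flat index
```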
+
+def get_A_shape(A):
+ if A is None:
+ return 'None'
+ if isinstance(A, Patches):
+ if A.patches is not None:
+ return A.patches.shape
+ else:
+ return A.shape
+ if isinstance(A, torch.Tensor):
+ return A.shape
+ return 'Unknown'
diff --git a/doc/README.md b/doc/README.md
index 48a25fd..c4b102b 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -1,7 +1,7 @@
-
# Documentation
This directory contains source files for building our documentation.
+Please view the compiled documentation on our [documentation page](https://auto-lirpa.readthedocs.io/en/latest/?badge=latest), as some links may not work here on GitHub.
## Dependencies
diff --git a/doc/api.rst b/doc/api.rst
index f70bf4b..5375cb5 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -18,8 +18,6 @@ API Usage
.. autofunction:: auto_LiRPA.perturbations.Perturbation.concretize
.. autofunction:: auto_LiRPA.perturbations.Perturbation.init
-.. autofunction:: auto_LiRPA.bound_op_map.register_custom_op
-
Indices and tables
-------------------
diff --git a/doc/process.py b/doc/process.py
index 93ffc48..6ad1116 100644
--- a/doc/process.py
+++ b/doc/process.py
@@ -5,7 +5,7 @@
from pygit2 import Repository
repo = 'https://github.com/KaidiXu/auto_LiRPA'
-branch = os.environ.get('BRANCH', None) or Repository('.').head.shorthand
+branch = Repository('.').head.shorthand
repo_file_path = os.path.join(repo, 'tree', branch)
""" Parse README.md into sections which can be reused """
@@ -40,7 +40,6 @@
source = file.read()
source_new = ''
ptr = 0
- # res = re.findall('\[.*\]\(.*\)', source)
for m in re.finditer('(\[.*\])(\(.*\))', source):
assert m.start() >= ptr
source_new += source[ptr:m.start()]
diff --git a/doc/src/bound_opts.md b/doc/src/bound_opts.md
index 4fc9f9c..3d21b72 100644
--- a/doc/src/bound_opts.md
+++ b/doc/src/bound_opts.md
@@ -1,34 +1,41 @@
Bound Options
====================
-Bound options can be set by passing a dictionary to the `bound_opts` argument for `BoundedModule`.
+Bound options can be set by passing a dictionary to the `bound_opts` argument for `BoundedModule`.
This page lists available bound options.
## Arguments for Optimizing Bounds (`optimize_bound_args`)
Arguments for optimizing bounds with the `CROWN-Optimized` method can be provided as a dictionary. Available arguments include:
-* `ob_alpha` (bool, default `True`): Enable α-CROWN (optimized CROWN/LiRPA).
+* `enable_alpha_crown` (bool, default `True`): Enable α-CROWN (optimized CROWN/LiRPA).
-* `ob_beta` (bool, default `False`): Enable β-CROWN.
+* `enable_beta_crown` (bool, default `False`): Enable β-CROWN.
-* `ob_optimizer` (str, default `adam`): Optimzier. Set it to `adam-autolr` to use `AdamElementLR`, or `sgd` to use SGD.
+* `optimizer` (str, default `adam`): Optimizer. Set it to `adam-autolr` to use `AdamElementLR`, or `sgd` to use SGD.
-* `ob_verbose` (int, default 0): If greater than 1, print verbosely.
+* `lr_alpha` (float, default 0.5), `lr_beta` (default 0.05): Learning rates for α and β parameters in α-CROWN and β-CROWN.
-* `ob_lr` (default 0.5), `ob_lr_beta` (default 0.05): Learning rates for α and β parameters in α-CROWN and β-CROWN.
+* `lr_decay` (float, default 0.98): Learning rate decay factor for the `ExponentialLR` scheduler.
-* `ob_lr_decay` (default 0.98): Learning rate decay factor for the `ExponentialLR` scheduler.
+* `iteration` (int): Number of optimization iterations.
-* `ob_iteration` (int): Number of optimization iterations.
+* `loss_reduction_func` (function): Function for loss reduction over the specification dimension. By default, use `auto_LiRPA.utils.reduction_sum`, which sums the bound over all batch elements and specifications.
-* `ob_loss_reduction_func` (function): Function for loss reduction over the specification dimension. By default, use `auto_LiRPA.utils.reduction_sum` which sumes the bound over all batch elements and specifications.
+* `stop_criterion_func` (function): Function for the criterion of stopping optimization early; it returns a tensor of `torch.bool` with `batch_size` elements. By default, it is a lambda function that always returns `False`. Several pre-defined options are `auto_LiRPA.utils.stop_criterion_min`, `auto_LiRPA.utils.stop_criterion_mean`, `auto_LiRPA.utils.stop_criterion_max` and `auto_LiRPA.utils.stop_criterion_sum`. For example, `auto_LiRPA.utils.stop_criterion_min` checks the minimum bound over all specifications of a batch element and returns `True` for that element when the minimum bound is greater than a specified threshold.
-* `ob_stop_criterion_func` (function): Function for the criterion of stopping optimization early; it returns a tensor of `torch.bool` with `batch_size` elements. By default, it is a lambda function that always returns `False` . Several pre-defined options are `auto_LiRPA.utils.stop_criterion_min`, `auto_LiRPA.utils.stop_criterion_mean`, `auto_LiRPA.utils.stop_criterion_max` and `auto_LiRPA.utils.stop_criterion_sum`. For example, `auto_LiRPA.utils.stop_criterion_min` checks the minimum bound over all specifications of a batch element and returns `True` for that element when the minimum bound is greater than a specified threshold.
+* `keep_best` (bool, default `True`): If `True`, save α, β and bounds at the best iteration. Otherwise the last iteration result is used.
-* `ob_keep_best` (bool, default `True`): If `True`, save α, β parameters at the best iteration. Otherwise the last iteration result is used.
+* `use_shared_alpha` (bool, default `False`): If `True`, all intermediate neurons from the same layer share the same set of α variables during bound optimization. For a very large model, enabling this option can save memory, at the cost of slightly looser bounds.
+
+* `fix_intermediate_layer_bounds` (bool, default `True`): Keep intermediate layer bounds fixed and only optimize the bounds of the last layer during α/β-CROWN.
+
+* `init_alpha` (bool, default `True`): Initialize α variables by calling CROWN once.
+
+* `early_stop_patience` (int, default 10): Number of iterations without improvement after which early stopping is considered.
+
+* `start_save_best` (float, default 0.5): Start saving the best optimized bounds once `current_iteration > int(iteration * start_save_best)`.
-* `ob_alpha_share_slopes` (bool, default `False`): If `True`, all intermediate neurons from the same layer share the same set of α variables during bound optimization. For a very large model, enabling this option can save memory, at a cost of slightly looser bound.
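
A hypothetical example of passing a few of these options (the values, `net`, `dummy_input` and `x` are chosen for illustration only):

```python
model = BoundedModule(net, dummy_input, bound_opts={
    'optimize_bound_args': {
        'iteration': 20,
        'lr_alpha': 0.1,
        'enable_beta_crown': False,
    },
})
lb, ub = model.compute_bounds(x=(x,), method='CROWN-Optimized')
```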
## ReLU (`relu`):
diff --git a/doc/src/custom_op.md b/doc/src/custom_op.md
index 853aa48..cc62038 100644
--- a/doc/src/custom_op.md
+++ b/doc/src/custom_op.md
@@ -13,7 +13,7 @@ There are three steps to write an operator:
3. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator.
-4. [Register the custom operator](api.html#auto_LiRPA.bound_op_map.register_custom_op).
+4. Create a mapping from the operator name (defined in step 1) to the bound class (defined in step 3). Define a `dict` in which each item is such a mapping, and pass the `dict` to the `custom_ops` argument when calling `BoundedModule` (see the [documentation](api.html#auto_LiRPA.BoundedModule)). For example, if the operator name is `MyRelu` and the bound class is `BoundMyRelu`, add `"MyRelu": BoundMyRelu` to the `dict`, as in the sketch below.
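
A minimal sketch of step 4, using the hypothetical names from the text above:

```python
custom_ops = {'MyRelu': BoundMyRelu}
model = BoundedModule(net, dummy_input, custom_ops=custom_ops)
```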
## Example
diff --git a/examples/.gitignore b/examples/.gitignore
new file mode 100644
index 0000000..c067428
--- /dev/null
+++ b/examples/.gitignore
@@ -0,0 +1 @@
+auto_LiRPA
diff --git a/examples/language/train.py b/examples/language/train.py
index 1d6d2a3..8c6f5b2 100644
--- a/examples/language/train.py
+++ b/examples/language/train.py
@@ -1,13 +1,14 @@
-"""
-A simple script to train certified defense LSTM or Transformer using the auto_LiRPA library.
-"""
-
import argparse
+import random
import pickle
import os
+import pdb
+import time
import logging
import numpy as np
import torch
+import torch.nn as nn
+import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput
@@ -69,18 +70,11 @@
args = parser.parse_args()
-# log writer in Tensorboard
writer = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10)
file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log'))
file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))
logger.addHandler(file_handler)
-random.seed(args.seed)
-np.random.seed(args.seed)
-torch.manual_seed(args.seed)
-torch.cuda.manual_seed_all(args.seed)
-
-## Step 1: Prepare dataset and Initial original model as usual
data_train_all_nodes, data_train, data_dev, data_test = load_data(args.data)
if args.robust:
data_dev, data_test = clean_data(data_dev), clean_data(data_test)
@@ -92,6 +86,10 @@
logger.info('Dataset sizes: {}/{}/{}/{}'.format(
len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test)))
+random.seed(args.seed)
+np.random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.cuda.manual_seed_all(args.seed)
dummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device)
dummy_labels = torch.zeros(1, dtype=torch.long, device=args.device)
@@ -102,17 +100,14 @@
elif args.model == 'lstm':
dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device)
model = LSTM(args, data_train)
-
+
dev_batches = get_batches(data_dev, args.batch_size)
test_batches = get_batches(data_test, args.batch_size)
-## Step 3: Define perturbation range, here we use synonym replacement perturbation constarint by args.budget
ptb = PerturbationSynonym(budget=args.budget)
dummy_embeddings = BoundedTensor(dummy_embeddings, ptb)
-
-## Step 4: wrap model with auto_LiRPA
model_ori = model.model_from_embeddings
-bound_opts = {'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True}
+bound_opts = { 'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True }
if isinstance(model_ori, BoundedModule):
model_bound = model_ori
else:
@@ -184,11 +179,9 @@ def step(model, ptb, batch, eps=1.0, train=False):
acc = (torch.argmax(logits, dim=1) == labels).float().mean()
if robust:
- # generate specifications
num_class = args.num_classes
c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) - \
torch.eye(num_class).type_as(embeddings).unsqueeze(0)
- # remove specifications to self
I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
c = (c[I].view(embeddings.size(0), num_class - 1, num_class))
if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
@@ -198,15 +191,11 @@ def step(model, ptb, batch, eps=1.0, train=False):
if 1 - eps > 1e-4:
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)
ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)
- # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
lb = eps * ilb + (1 - eps) * lb
else:
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
else:
raise NotImplementedError
-
- # Pad zero at the beginning for each example, and use fake label "0" for all examples because the margins
- # have already been calculated by specifications
lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)
fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
@@ -237,7 +226,6 @@ def train(epoch, batches, type):
assert(optimizer is not None)
train = type == 'train'
if args.robust:
- # epsilon dynamically growth
eps_scheduler.set_epoch_length(len(batches))
if train:
eps_scheduler.train()
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 279168b..2aeee45 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -1,11 +1,12 @@
thop>=0.0.31.post2004101309
tensorboard>=1.14
scikit_learn>=0.21
-torchvision>=0.9.1
-torch>=1.8,<1.9
+torchvision>=0.9.1,<0.13
+torch>=1.8
scipy>=1.3
pytorch_pretrained_bert>=0.6
query>=0.1
tqdm>=4.43
matplotlib>=3.2
sortedcontainers>=2.4
+psutil>=5.8
diff --git a/examples/vision/cifar_training.py b/examples/vision/cifar_training.py
index 9feb315..d4019a9 100644
--- a/examples/vision/cifar_training.py
+++ b/examples/vision/cifar_training.py
@@ -3,11 +3,11 @@
import random
import time
import logging
+import os
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
-from thop import profile
from torch.nn import CrossEntropyLoss
import models
@@ -55,7 +55,7 @@ def get_exp_module(bounded_module):
os.makedirs('saved_models/', exist_ok=True)
log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log'
file_handler = logging.FileHandler(log_file)
-logger.addHandler(file_handler)
+logger.addHandler(file_handler)
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None):
num_class = 10
@@ -172,16 +172,6 @@ def get_bound_loss(x=None, c=None):
opt.step()
meter.update('Loss', loss.item(), data.size(0))
- # check gradient
- # for n, p in model.named_parameters():
- # if p.grad is None:
- # print('gradient for layer {} is NULL!!!'.format(n))
- # else:
- # print('gradient for layer {} is not null'.format(n))
- # print(p.grad.flatten()[:8])
- #
- # sys.exit()
-
if batch_method != 'natural':
meter.update('Robust_CE', robust_ce.item(), data.size(0))
if not loss_fusion:
@@ -262,9 +252,6 @@ def main(args):
final_name2 = None
model_loss = BoundDataParallel(model_loss)
- macs, params = profile(model_ori, (dummy_input.cuda(),))
- logger.info('macs: {}, params: {}'.format(macs, params))
-
## Step 4 prepare optimizer, epsilon scheduler and learning rate scheduler
opt = optim.Adam(model_loss.parameters(), lr=args.lr)
norm = float(args.norm)
diff --git a/examples/vision/custom_op.py b/examples/vision/custom_op.py
index 1d5d96e..5b70fa0 100644
--- a/examples/vision/custom_op.py
+++ b/examples/vision/custom_op.py
@@ -1,7 +1,7 @@
""" A example for custom operators.
-In this example, we create a custom operator called "PlusConstant", which can
-be written as "f(x) = x + c" for some constant "c" (an attribute of the operator).
+In this example, we create a custom operator called "PlusConstant", which can
+be written as "f(x) = x + c" for some constant "c" (an attribute of the operator).
"""
import torch
import torch.nn as nn
@@ -11,22 +11,22 @@
from auto_LiRPA.perturbations import PerturbationLpNorm
from auto_LiRPA.utils import Flatten
-""" Step 1: Define a `torch.autograd.Function` class to declare and implement the
+""" Step 1: Define a `torch.autograd.Function` class to declare and implement the
computation of the operator. """
class PlusConstantOp(torch.autograd.Function):
@staticmethod
def symbolic(g, x, const):
""" In this function, define the arguments and attributes of the operator.
"custom::PlusConstant" is the name of the new operator, "x" is an argument
- of the operator, "const_i" is an attribute which stands for "c" in the operator.
- There can be multiple arguments and attributes. For attribute naming,
- use a suffix such as "_i" to specify the data type, where "_i" stands for
+ of the operator, "const_i" is an attribute which stands for "c" in the operator.
+ There can be multiple arguments and attributes. For attribute naming,
+ use a suffix such as "_i" to specify the data type, where "_i" stands for
integer, "_t" stands for tensor, "_f" stands for float, etc. """
return g.op('custom::PlusConstant', x, const_i=const)
@staticmethod
def forward(ctx, x, const):
- """ In this function, implement the computation for the operator, i.e.,
+ """ In this function, implement the computation for the operator, i.e.,
f(x) = x + c in this case. """
return x + const
@@ -43,9 +43,9 @@ def forward(self, x):
""" Step 3: Implement a Bound class to support bound computation for the new operator. """
class BoundPlusConstant(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
+ def __init__(self, attr, inputs, output_index, options):
""" `const` is an attribute and can be obtained from the dict `attr` """
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ super().__init__(attr, inputs, output_index, options)
self.const = attr['const']
def forward(self, x):
@@ -67,7 +67,7 @@ def _bound_oneside(last_A):
bias = last_A.sum(dim=list(range(2, last_A.ndim))) * self.const
return A, bias
lA, lbias = _bound_oneside(last_lA)
- uA, ubias = _bound_oneside(last_uA)
+        uA, ubias = _bound_oneside(last_uA)
return [(lA, uA)], lbias, ubias
def interval_propagate(self, *v):
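+        # For f(x) = x + c, interval propagation simply shifts both interval ends
+        # by the constant: [l + c, u + c].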
diff --git a/examples/vision/data/.gitignore b/examples/vision/data/.gitignore
new file mode 100644
index 0000000..13ece7c
--- /dev/null
+++ b/examples/vision/data/.gitignore
@@ -0,0 +1,2 @@
+MNIST
+cifar*
\ No newline at end of file
diff --git a/examples/vision/data/tinyImageNet/.gitignore b/examples/vision/data/tinyImageNet/.gitignore
new file mode 100644
index 0000000..0c46dd7
--- /dev/null
+++ b/examples/vision/data/tinyImageNet/.gitignore
@@ -0,0 +1 @@
+tiny-imagenet-200*
diff --git a/examples/vision/imagenet_training.py b/examples/vision/imagenet_training.py
index 6e9791f..22425c9 100644
--- a/examples/vision/imagenet_training.py
+++ b/examples/vision/imagenet_training.py
@@ -54,7 +54,7 @@ def get_exp_module(bounded_module):
args.num_epochs) + '_' + args.scheduler_opts + '_' + str(args.eps)[:6]
log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log'
file_handler = logging.FileHandler(log_file)
-logger.addHandler(file_handler)
+logger.addHandler(file_handler)
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True,
final_node_name=None):
diff --git a/examples/vision/simple_verification.py b/examples/vision/simple_verification.py
index 9b7629e..b20c61b 100644
--- a/examples/vision/simple_verification.py
+++ b/examples/vision/simple_verification.py
@@ -4,17 +4,15 @@
This example serves as a skeleton for robustness verification of neural networks.
"""
import os
+from collections import defaultdict
import torch
import torch.nn as nn
import torchvision
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
+from auto_LiRPA.utils import Flatten
## Step 1: Define computational graph by implementing forward()
-class Flatten(nn.Module):
- def forward(self, x):
- return x.view(x.size(0), -1)
-
# This simple model comes from https://github.com/locuslab/convex_adversarial
def mnist_model():
model = nn.Sequential(
@@ -31,11 +29,15 @@ def mnist_model():
model = mnist_model()
# Optionally, load the pretrained weights.
-checkpoint = torch.load(os.path.join(os.path.dirname(__file__),"pretrain/mnist_a_adv.pth"), map_location=torch.device('cpu'))
+checkpoint = torch.load(
+ os.path.join(os.path.dirname(__file__), 'pretrain/mnist_a_adv.pth'),
+ map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
## Step 2: Prepare dataset as usual
-test_data = torchvision.datasets.MNIST("./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
+test_data = torchvision.datasets.MNIST(
+ './data', train=False, download=True,
+ transform=torchvision.transforms.ToTensor())
# For illustration, we only use 2 images from the dataset
N = 2
n_classes = 10
@@ -48,7 +50,8 @@ def mnist_model():
model = model.cuda()
## Step 3: wrap model with auto_LiRPA
-# The second parameter is for constructing the trace of the computational graph, and its content is not important.
+# The second parameter is for constructing the trace of the computational graph,
+# and its content is not important.
lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)
print('Running on', image.device)
@@ -60,23 +63,51 @@ def mnist_model():
# Get model prediction as usual
pred = lirpa_model(image)
label = torch.argmax(pred, dim=1).cpu().detach().numpy()
+print('Demonstration 1: Bound computation and comparisons of different methods.\n')
## Step 5: Compute bounds for final output
-for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN-Optimized (alpha-CROWN)']:
-# for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', ]:
- print("Bounding method:", method)
+for method in [
+ 'IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)',
+ 'CROWN-Optimized (alpha-CROWN)']:
+ print('Bounding method:', method)
if 'Optimized' in method:
# For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.
- lirpa_model.set_bound_opts({'optimize_bound_args': {'ob_iteration': 20, 'ob_lr': 0.1, 'ob_verbose': 0}})
+ lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})
lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0])
for i in range(N):
- print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
+ print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')
for j in range(n_classes):
indicator = '(ground-truth)' if j == true_label[i] else ''
- print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format(
+ print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(
j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))
print()
+print('Demonstration 2: Obtaining linear coefficients of the lower and upper bounds.\n')
+# There are many bound coefficients during CROWN bound calculation; here we are interested in the linear bounds
+# of the output layer, with respect to the input layer (the image).
+required_A = defaultdict(set)
+required_A[lirpa_model.output_name[0]].add(lirpa_model.input_name[0])
+
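+# A brief note: for each requested (output, input) pair, A_dict stores the tensors
+# 'lA', 'lbias', 'uA' and 'ubias', which form linear bounds
+# lA * x + lbias <= f(x) <= uA * x + ubias, valid for every x in the perturbation set.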
+for method in [
+ 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN',
+ 'CROWN-Optimized (alpha-CROWN)']:
+ print("Bounding method:", method)
+ if 'Optimized' in method:
+ # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.
+ lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})
+ lb, ub, A_dict = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], return_A=True, needed_A_dict=required_A)
+ lower_A, lower_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lbias']
+ upper_A, upper_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['uA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['ubias']
+ print(f'lower bound linear coefficients size (batch, output_dim, *input_dims): {list(lower_A.size())}')
+ print(f'lower bound linear coefficients norm (smaller is better): {lower_A.norm()}')
+ print(f'lower bound bias term size (batch, output_dim): {list(lower_bias.size())}')
+ print(f'lower bound bias term sum (larger is better): {lower_bias.sum()}')
+ print(f'upper bound linear coefficients size (batch, output_dim, *input_dims): {list(upper_A.size())}')
+ print(f'upper bound linear coefficients norm (smaller is better): {upper_A.norm()}')
+ print(f'upper bound bias term size (batch, output_dim): {list(upper_bias.size())}')
+ print(f'upper bound bias term sum (smaller is better): {upper_bias.sum()}')
+print('These linear lower and upper bounds are valid everywhere within the perturbation radii.\n')
+
## An example for computing margin bounds.
# In compute_bounds() function you can pass in a specification matrix C, which is a final linear matrix applied to the last layer NN output.
# For example, if you are interested in the margin between the groundtruth class and another class, you can use C to specify the margin.
@@ -89,17 +120,18 @@ def mnist_model():
target_label = (groundtruth + 1) % n_classes
C.scatter_(dim=2, index=groundtruth, value=1.0)
C.scatter_(dim=2, index=target_label, value=-1.0)
-print('Computing bounds with a specification matrix:\n', C)
+print('Demonstration 3: Computing bounds with a specification matrix.\n')
+print('Specification matrix:\n', C)
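+# Each row of C places +1 on the ground-truth class and -1 on the target class, so
+# the bounds computed below are margins f_y(x+delta) - f_target(x+delta); a positive
+# lower bound certifies that the target class cannot overtake the ground truth.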
for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN-Optimized (alpha-CROWN)']:
- print("Bounding method:", method)
+ print('Bounding method:', method)
if 'Optimized' in method:
# For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.
- lirpa_model.set_bound_opts({'optimize_bound_args': {'ob_iteration': 20, 'ob_lr': 0.1, 'ob_verbose': 0}})
+ lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1, }})
lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], C=C)
for i in range(N):
- print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
- print("margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}".format(
+ print('Image {} top-1 prediction {} ground-truth {}'.format(i, label[i], true_label[i]))
+ print('margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}'.format(
j=true_label[i], target=(true_label[i] + 1) % n_classes, l=lb[i][0].item(), u=ub[i][0].item()))
print()
diff --git a/examples/vision/tinyimagenet_training.py b/examples/vision/tinyimagenet_training.py
index 3f079a5..3264ca2 100644
--- a/examples/vision/tinyimagenet_training.py
+++ b/examples/vision/tinyimagenet_training.py
@@ -1,3 +1,4 @@
+import os
import random
import time
import argparse
@@ -55,7 +56,7 @@ def get_exp_module(bounded_module):
os.makedirs('saved_models/', exist_ok=True)
log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log'
file_handler = logging.FileHandler(log_file)
-logger.addHandler(file_handler)
+logger.addHandler(file_handler)
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True,
final_node_name=None):
diff --git a/examples/vision/verify_two_node.py b/examples/vision/verify_two_node.py
index 7bbd697..9895d8a 100644
--- a/examples/vision/verify_two_node.py
+++ b/examples/vision/verify_two_node.py
@@ -1,10 +1,11 @@
"""
-Example for multi-node perturbation. An input image is splited to two parts
-where each part is perturbed respectively constained by L-inf norm. It is
-expected to output the same results as running `simple_verification.py` where
+Example for multi-node perturbation. An input image is split into two parts,
+and each part is perturbed independently, constrained by an L-inf norm. It is
+expected to output the same results as running `simple_verification.py`, where
the whole image is perturbed under an L-inf norm.
"""
+import os
import torch.nn as nn
import torch.nn.functional as F
import torchvision
@@ -35,7 +36,8 @@ def forward(self, x, y):
model.load_state_dict(checkpoint)
## Step 2: Prepare dataset as usual
-test_data = torchvision.datasets.MNIST("./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
+test_data = torchvision.datasets.MNIST(
+ "./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
# For illustration, we only use 2 images from the dataset
N = 2
n_classes = 10
@@ -44,9 +46,13 @@ def forward(self, x, y):
image = image.to(torch.float32) / 255.0
## Step 3: wrap model with auto_LiRPA
-# The second parameter is for constructing the trace of the computational graph, and its content is not important.
+# The second parameter is for constructing the trace of the computational graph,
+# and its content is not important.
image_1, image_2 = torch.split(torch.empty_like(image), [14, 14], dim=2)
-model = BoundedModule(model, (image_1, image_2), device="cuda")
+model = BoundedModule(
+ model, (image_1, image_2), device="cuda",
+ bound_opts={'conv_mode': 'matrix'} # Patches mode is not supported currently
+)
## Step 4: Compute bounds using LiRPA given a perturbation
eps = 0.3
@@ -68,6 +74,6 @@ def forward(self, x, y):
for i in range(N):
print("Image {} top-1 prediction {}".format(i, label[i]))
for j in range(n_classes):
- print("f_{j}(x_0) = {fx0:8.3f}, {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j]))
+ print("f_{j}(x_0) = {fx0:8.3f}, {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(
+ j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j]))
print()
-
diff --git a/examples/vision/weight_perturbation_training.py b/examples/vision/weight_perturbation_training.py
index dbc2bef..13bad55 100644
--- a/examples/vision/weight_perturbation_training.py
+++ b/examples/vision/weight_perturbation_training.py
@@ -11,6 +11,7 @@
"""
import random
import time
+import os
import argparse
import logging
import torch.optim as optim
@@ -217,17 +218,22 @@ def main(args):
## Step 3: wrap model with auto_LiRPA
# The second parameter dummy_input is for constructing the trace of the computational graph.
- model = BoundedModule(model_ori, dummy_input, bound_opts={'relu':args.bound_opts}, device=args.device)
+ model = BoundedModule(model_ori, dummy_input, device=args.device, bound_opts={
+ 'relu':args.bound_opts, 'sparse_intermediate_bounds': False,
+ 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})
final_name1 = model.final_name
model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),
- bound_opts= { 'relu': args.bound_opts, 'loss_fusion': True }, device=args.device)
+ device=args.device, bound_opts= {'relu': args.bound_opts, 'loss_fusion': True,
+ 'sparse_intermediate_bounds': False,
+ 'sparse_conv_intermediate_bounds': False,
+ 'sparse_intermediate_bounds_with_ibp': False})
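+    # Note (assumption): the sparse intermediate-bound optimizations are disabled
+    # for both models here, presumably because they target input perturbations and
+    # do not apply when the perturbation is on the model weights.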
# after CrossEntropyWrapper, the final name will change because of one more input node in CrossEntropyWrapper
final_name2 = model_loss._modules[final_name1].output_name[0]
assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
if args.multigpu:
model_loss = BoundDataParallel(model_loss)
- model_loss.ptb = model.ptb = model_ori.ptb # Perturbation on the parameters
+ model_loss.ptb = model.ptb = model_ori.ptb # Perturbation on the parameters
## Step 4 prepare optimizer, epsilon scheduler and learning rate scheduler
if args.opt == 'ADAM':
diff --git a/setup.py b/setup.py
index 0807d51..5428841 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,8 @@
from setuptools import setup, find_packages
-import sys
"""Check PyTorch version"""
pytorch_version_l = "1.8.0"
-pytorch_version_u = "1.9.0" # excluded
+pytorch_version_u = "1.13.0" # excluded
msg_install_pytorch = (f"It is recommended to manually install PyTorch "
f"(>={pytorch_version_u},<{pytorch_version_u}) suitable "
"for your system ahead: https://pytorch.org/get-started.\n")
@@ -14,7 +13,7 @@
+ msg_install_pytorch)
if torch.__version__ >= pytorch_version_u:
print(f'PyTorch version {torch.__version__} is too high. '
- + msg_install_pytorch)
+ + msg_install_pytorch)
except ModuleNotFoundError:
print(f'PyTorch is not installed. {msg_install_pytorch}')
@@ -23,13 +22,6 @@
if '__version__' in line:
version = eval(line.strip().split()[-1])
-assert sys.version_info.major == 3, 'Python 3 is required'
-if sys.version_info.minor < 8:
- # numpy 1.22 requires Python 3.8+
- numpy_requirement = 'numpy>=1.16,<=1.21'
-else:
- numpy_requirement = 'numpy>=1.16'
-
print(f'Installing auto_LiRPA {version}')
setup(
name='auto_LiRPA',
@@ -41,12 +33,15 @@
packages=find_packages(),
install_requires=[
f'torch>={pytorch_version_l},<{pytorch_version_u}',
- 'torchvision>=0.9,<0.10',
- numpy_requirement,
+ 'torchvision>=0.9,<0.14',
+ 'numpy>=1.16',
'packaging>=20.0',
'pytest>=5.0',
+ 'pylint>=2.15',
+ 'pytest-order>=1.0.0',
'appdirs>=1.4',
'pyyaml>=5.0',
+ 'ninja>=1.10',
],
platforms=['any'],
license='BSD',
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 0000000..16d3c4d
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1 @@
+.cache
diff --git a/tests/auto_LiRPA b/tests/auto_LiRPA
new file mode 120000
index 0000000..a0acdc3
--- /dev/null
+++ b/tests/auto_LiRPA
@@ -0,0 +1 @@
+../auto_LiRPA
\ No newline at end of file
diff --git a/tests/data/.gitignore b/tests/data/.gitignore
index 677d9db..c40e49a 100644
--- a/tests/data/.gitignore
+++ b/tests/data/.gitignore
@@ -2,4 +2,4 @@ ckpt_lstm
ckpt_transformer
cifar-10-python.tar.gz
cifar-10-batches-py
-
+MNIST
\ No newline at end of file
diff --git a/tests/data/beta_crown_test_data b/tests/data/beta_crown_test_data
new file mode 100644
index 0000000..6090ff8
Binary files /dev/null and b/tests/data/beta_crown_test_data differ
diff --git a/tests/data/maxpool_test_data b/tests/data/maxpool_test_data
index 4d230bc..b538cc9 100755
Binary files a/tests/data/maxpool_test_data and b/tests/data/maxpool_test_data differ
diff --git a/tests/data/maxpool_test_data_3-0-3-0 b/tests/data/maxpool_test_data_3-0-3-0
new file mode 100644
index 0000000..41191fa
Binary files /dev/null and b/tests/data/maxpool_test_data_3-0-3-0 differ
diff --git a/tests/data/maxpool_test_data_3-0-3-1 b/tests/data/maxpool_test_data_3-0-3-1
new file mode 100644
index 0000000..351e86d
Binary files /dev/null and b/tests/data/maxpool_test_data_3-0-3-1 differ
diff --git a/tests/data/maxpool_test_data_4-0-4-0 b/tests/data/maxpool_test_data_4-0-4-0
new file mode 100644
index 0000000..e054b4a
Binary files /dev/null and b/tests/data/maxpool_test_data_4-0-4-0 differ
diff --git a/tests/data/maxpool_test_data_4-0-4-1 b/tests/data/maxpool_test_data_4-0-4-1
new file mode 100644
index 0000000..83f17ef
Binary files /dev/null and b/tests/data/maxpool_test_data_4-0-4-1 differ
diff --git a/tests/data/vision_test_data b/tests/data/vision_test_data
index 31715a5..9f2b7fd 100644
Binary files a/tests/data/vision_test_data and b/tests/data/vision_test_data differ
diff --git a/tests/test_1d_activation.py b/tests/test_1d_activation.py
index 22dfa2a..232b47b 100644
--- a/tests/test_1d_activation.py
+++ b/tests/test_1d_activation.py
@@ -1,5 +1,6 @@
"""Test one dimensional activation functions (e.g., ReLU, tanh, exp, sin, etc)"""
import torch
+import torch.nn as nn
import os
from testcase import TestCase
from auto_LiRPA import BoundedModule, BoundedTensor
@@ -21,6 +22,8 @@ def __init__(self, methodName='runTest'):
super().__init__(methodName)
def create_test(self, act_func, low, high, ntests=10000, nsamples=1000, method='IBP'):
+ print(f'Testing activation {act_func}')
+
model = test_model(act_func)
image = torch.zeros(1, ntests)
bounded_model = BoundedModule(model, image)
@@ -83,6 +86,22 @@ def lookup(l, u):
assert torch.all(output_lb - 1e-5 <= ref_output_lb)
+ def _single(self):
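+        # A manual smoke test for sin on one fixed interval; the leading underscore
+        # keeps pytest from collecting it, and it only prints the resulting CROWN
+        # bounds instead of asserting.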
+ model = test_model(torch.sin)
+ image = torch.zeros(1, 1)
+ bounded_model = BoundedModule(model, image)
+
+ input_lb = torch.tensor([2.817])
+ input_ub = torch.tensor([5.196])
+ input_center = (input_lb + input_ub) / 2.0
+ ptb = PerturbationLpNorm(norm=float("inf"), eps=None, x_L=input_lb, x_U=input_ub)
+ ptb_data = BoundedTensor(input_center, ptb)
+
+ # Get bounding results.
+ forward = bounded_model(ptb_data)
+ output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data,), method = 'CROWN')
+ print(output_lb, output_ub)
+
def test_relu(self):
self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='IBP')
self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='CROWN')
@@ -106,13 +125,22 @@ def test_tanh(self):
def test_sin(self):
self.create_test(act_func=torch.sin, low=-10, high=10, method='IBP')
- # self.create_test(act_func=torch.sin, low=-10, high=10, method='CROWN')
+ self.create_test(act_func=torch.sin, low=-10, high=10, method='CROWN')
def test_cos(self):
self.create_test(act_func=torch.cos, low=-10, high=10, method='IBP')
- # self.create_test(act_func=torch.cos, low=-10, high=10, method='CROWN')
+ self.create_test(act_func=torch.cos, low=-10, high=10, method='CROWN')
+ def test_arctan(self):
+ self.create_test(act_func=torch.arctan, low=-10, high=10, method='IBP')
+ self.create_test(act_func=torch.arctan, low=-10, high=10, method='CROWN')
+
+ def test_tan(self):
+ # Test tan(x) in different periods.
+ for i in range(-5, 5):
+            self.create_test(act_func=torch.tan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='IBP')
+            self.create_test(act_func=torch.tan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='CROWN')
if __name__ == '__main__':
testcase = Test1DActivation()
@@ -122,5 +150,5 @@ def test_cos(self):
testcase.test_tanh()
testcase.test_sin()
testcase.test_cos()
-
-
+ testcase.test_arctan()
+ testcase.test_tan()
diff --git a/tests/test_pooling.py b/tests/test_avgpool.py
similarity index 96%
rename from tests/test_pooling.py
rename to tests/test_avgpool.py
index bc5ec23..cf4b45b 100644
--- a/tests/test_pooling.py
+++ b/tests/test_avgpool.py
@@ -1,4 +1,4 @@
-# Test bounds on a 1 layer linear network.
+"""Test average pooling."""
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
diff --git a/tests/test_bound_ops.py b/tests/test_bound_ops.py
index 5b3b157..a49adc4 100644
--- a/tests/test_bound_ops.py
+++ b/tests/test_bound_ops.py
@@ -1,10 +1,10 @@
"""Test classes for bound operators"""
import torch
-import os
from auto_LiRPA.bound_ops import *
-from auto_LiRPA.utils import LinearBound
+from auto_LiRPA.linear_bound import LinearBound
from testcase import TestCase
+
"""Dummy node for testing"""
class Dummy:
def __init__(self, lower, upper=None, perturbed=False):
@@ -36,10 +36,9 @@ def test(self):
dummy_bias = Dummy(bias)
op = BoundLinear(
- input_name=[None, None, None],
- name=None, ori_name=None, attr=None,
+ attr={},
inputs=[dummy_in, dummy_weight, dummy_bias],
- output_index=0, options={}, device=device)
+ output_index=0, options={})
# test `forward`
data_out = op(data_in, weight, bias)
diff --git a/tests/test_constant.py b/tests/test_constant.py
index cc4a42e..756a125 100644
--- a/tests/test_constant.py
+++ b/tests/test_constant.py
@@ -57,4 +57,5 @@ def test(self):
if __name__ == '__main__':
    # Change to generate=True when generating reference results
testcase = TestConstant(generate=False)
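+    # setUp() is normally invoked by the unittest runner; call it explicitly since
+    # this file may be run directly as a script.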
+ testcase.setUp()
testcase.test()
diff --git a/tests/test_conv.py b/tests/test_conv.py
index 9390384..b9e571f 100644
--- a/tests/test_conv.py
+++ b/tests/test_conv.py
@@ -15,7 +15,7 @@ def forward(self, x):
return x.view((x.shape[0], -1))
class cnn_model(nn.Module):
- def __init__(self, layers, padding, stride):
+ def __init__(self, layers, padding, stride, linear=True):
super(cnn_model, self).__init__()
self.module_list = []
channel = 1
@@ -27,8 +27,9 @@ def __init__(self, layers, padding, stride):
assert length > 0
self.module_list.append(nn.ReLU())
self.module_list.append(Flatten())
- self.module_list.append(nn.Linear(3 * length * length, 256))
- self.module_list.append(nn.Linear(256, 10))
+ if linear:
+ self.module_list.append(nn.Linear(3 * length * length, 256))
+ self.module_list.append(nn.Linear(256, 10))
self.model = nn.Sequential(*self.module_list)
def forward(self, x):
@@ -42,7 +43,7 @@ def __init__(self, methodName='runTest', generate=False):
generate=generate)
def test(self):
- models = [2, 3]
+ models = [1, 2, 3]
paddings = [1, 2]
strides = [1, 3]
@@ -54,26 +55,32 @@ def test(self):
for layer_num in models:
for padding in paddings:
for stride in strides:
- try:
+ for linear in [True, False]:
-                        model_ori = cnn_model(layer_num, padding, stride)
-                    except:
-                        continue
+                        model_ori = cnn_model(layer_num, padding, stride, linear=linear)
+                        print('Model:', model_ori)
- model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "patches"})
- eps = 0.3
- norm = np.inf
- ptb = PerturbationLpNorm(norm=norm, eps=eps)
- image = BoundedTensor(image, ptb)
- pred = model(image)
- lb, ub = model.compute_bounds()
+ model = BoundedModule(model_ori, image, bound_opts={"conv_mode": "patches"})
+ eps = 0.3
+ norm = np.inf
+ ptb = PerturbationLpNorm(norm=norm, eps=eps)
+ image = BoundedTensor(image, ptb)
+ pred = model(image)
+ lb, ub = model.compute_bounds()
- model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "matrix"})
- pred = model(image)
- lb_ref, ub_ref = model.compute_bounds()
+ model = BoundedModule(model_ori, image, bound_opts={"conv_mode": "matrix"})
+ pred = model(image)
+ lb_ref, ub_ref = model.compute_bounds()
- assert lb.shape == ub.shape == torch.Size((N, n_classes))
- self.assertEqual(lb, lb_ref)
- self.assertEqual(ub, ub_ref)
+ if linear:
+ assert lb.shape == ub.shape == torch.Size((N, n_classes))
+ self.assertEqual(lb, lb_ref)
+ self.assertEqual(ub, ub_ref)
+
+ if not linear and layer_num == 1:
+ pred = model(image)
+ lb_forward, ub_forward = model.compute_bounds(method='forward')
+ self.assertEqual(lb, lb_forward)
+ self.assertEqual(ub, ub_forward)
if __name__ == '__main__':
testcase = TestConv()
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 0fae3d7..f2de101 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,4 +1,7 @@
-"""Test all the examples before release"""
+"""Test all the examples before release.
+
+This script is expected to be run manually and is not used in automatic tests."""
+
import pytest
import subprocess
import os
@@ -8,7 +11,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--test', type=str, default=None)
-args = parser.parse_args()
+args = parser.parse_args()
pytest_skip = pytest.mark.skip(
reason="It should be tested on a GPU server and excluded from CI")
@@ -24,14 +27,14 @@ def download_data_language():
url = "http://download.huan-zhang.com/datasets/language/data_language.tar.gz"
if not os.path.exists('../examples/language/data/sst'):
subprocess.run(shlex.split(f"wget {url}"), cwd="../examples/language")
- subprocess.run(shlex.split(f"tar xvf data_language.tar.gz"),
+ subprocess.run(shlex.split(f"tar xvf data_language.tar.gz"),
cwd="../examples/language")
@pytest_skip
def test_transformer():
cmd = f"""python train.py --dir {cache_dir} --robust
--method IBP+backward_train --train --num_epochs 2 --num_epochs_all_nodes 2
- --eps_start 2 --eps_length 1 --eps 0.1"""
+ --eps_start 2 --eps_length 1 --eps 0.1"""
print(cmd, file=sys.stderr)
download_data_language()
subprocess.run(shlex.split(cmd), cwd='../examples/language')
@@ -47,19 +50,36 @@ def test_lstm():
download_data_language()
subprocess.run(shlex.split(cmd), cwd='../examples/language')
-#FIXME this is broken
@pytest_skip
def test_lstm_seq():
cmd = f"""python train.py --dir {cache_dir}
--hidden_size 2 --num_epochs 2 --num_slices 4"""
print(cmd, file=sys.stderr)
- subprocess.run(shlex.split(cmd), cwd='../examples/sequence')
+ subprocess.run(shlex.split(cmd), cwd='../examples/sequence')
@pytest_skip
def test_simple_verification():
cmd = "python simple_verification.py"
print(cmd, file=sys.stderr)
- subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+ subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_custom_op():
+ cmd = "python custom_op.py"
+ print(cmd, file=sys.stderr)
+ subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_efficient_convolution():
+ cmd = "python efficient_convolution.py"
+ print(cmd, file=sys.stderr)
+ subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_two_node():
+ cmd = "python verify_two_node.py"
+ print(cmd, file=sys.stderr)
+ subprocess.run(shlex.split(cmd), cwd='../examples/vision')
@pytest_skip
def test_simple_training():
@@ -92,8 +112,8 @@ def test_tinyimagenet():
--in_planes 2 --widen_factor 2"""
print(cmd, file=sys.stderr)
if not os.path.exists('../examples/vision/data/tinyImageNet/tiny-imagenet-200'):
- subprocess.run(shlex.split("bash tinyimagenet_download.sh"),
- cwd="../examples/vision/data/tinyImageNet")
+ subprocess.run(shlex.split("bash tinyimagenet_download.sh"),
+ cwd="../examples/vision/data/tinyImageNet")
subprocess.run(shlex.split(cmd), cwd='../examples/vision')
@pytest_skip
@@ -103,7 +123,7 @@ def test_imagenet():
--num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255}
--in_planes 2 --widen_factor 2"""
print(cmd)
- if (not os.path.exists('../examples/vision/data/ImageNet64/train') or
+ if (not os.path.exists('../examples/vision/data/ImageNet64/train') or
not os.path.exists('../examples/vision/data/ImageNet64/test')):
print('Error: ImageNet64 dataset is not ready.')
return -1
@@ -111,19 +131,18 @@ def test_imagenet():
@pytest_skip
def test_release():
- if args.test:
- # Only run a specified test
- eval(f'test_{args.test}')()
- else:
- # Run all tests
- test_simple_training()
- test_transformer()
- test_lstm()
- test_lstm_seq()
- test_simple_verification()
- test_cifar_training()
- test_weight_perturbation()
- test_tinyimagenet()
+ """Run all tests."""
+ test_simple_training()
+ test_transformer()
+ test_lstm()
+ test_lstm_seq()
+ test_simple_verification()
+ test_cifar_training()
+ test_weight_perturbation()
+ test_tinyimagenet()
+ test_custom_op()
+ test_efficient_convolution()
+ test_two_node()
if __name__ == '__main__':
test_release()
diff --git a/tests/test_language_models.py b/tests/test_language_models.py
index 72cf035..aa4f6aa 100644
--- a/tests/test_language_models.py
+++ b/tests/test_language_models.py
@@ -43,10 +43,14 @@ def train():
os.system("rm -rf ../examples/language/model_transformer_test")
if os.path.exists("../examples/language/model_lstm_test"):
os.system("rm -rf ../examples/language/model_lstm_test")
- logger.info("Training a Transformer")
+ logger.info("\nTraining a Transformer")
+ print(cmd_transformer_train)
+ print()
os.system(cmd_transformer_train)
os.system("cp ../examples/language/model_transformer_test/ckpt_2 data/ckpt_transformer")
- logger.info("Training an LSTM")
+ logger.info("\nTraining an LSTM")
+ print(cmd_lstm_train)
+ print()
os.system(cmd_lstm_train)
os.system("cp ../examples/language/model_lstm_test/ckpt_2 data/ckpt_lstm")
@@ -55,10 +59,14 @@ def read_res():
return pickle.load(file)
def evaluate():
- logger.info('Evaluating the trained LSTM')
+ logger.info('\nEvaluating the trained LSTM')
+ print(cmd_lstm_test)
+ print()
os.system(cmd_lstm_test)
res_lstm = read_res()
- logger.info('Evaluating the trained Transformer')
+ logger.info('\nEvaluating the trained Transformer')
+ print(cmd_transformer_test)
+ print()
os.system(cmd_transformer_test)
res_transformer = read_res()
os.system("rm {}".format(res_path))
diff --git a/tests/test_linear_cnn_model.py b/tests/test_linear_cnn_model.py
index 975336b..eb0cb10 100644
--- a/tests/test_linear_cnn_model.py
+++ b/tests/test_linear_cnn_model.py
@@ -1,4 +1,4 @@
-# Test bounds on a 1 layer CNN network.
+"""Test bounds on a 1 layer CNN network."""
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
@@ -19,7 +19,7 @@ def forward(self, x):
x = x.view(-1, input_dim //2 * input_dim // 2 * out_channel)
return x
-class TestLinearCNNModel(TestLinearModel):
+class TestLinearCNNModel(TestLinearModel):
def __init__(self, methodName='runTest', generate=False):
super().__init__(methodName)
self.original_model = LinearCNNModel()
diff --git a/tests/test_linear_model.py b/tests/test_linear_model.py
index 305d661..b762195 100644
--- a/tests/test_linear_model.py
+++ b/tests/test_linear_model.py
@@ -1,4 +1,4 @@
-# Test bounds on a 1 layer linear network.
+"""Test bounds on a 1 layer linear network."""
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
@@ -17,7 +17,7 @@ def forward(self, x):
x = self.fc(x)
return x
-class TestLinearModel(TestCase):
+class TestLinearModel(TestCase):
def __init__(self, methodName='runTest', generate=False):
super().__init__(methodName, seed=0)
self.original_model = LinearModel()
diff --git a/tests/test_maxpool.py b/tests/test_maxpool.py
index 3b731d1..e77e15c 100644
--- a/tests/test_maxpool.py
+++ b/tests/test_maxpool.py
@@ -1,82 +1,120 @@
+"""Test max pooling."""
+
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from auto_LiRPA import BoundedModule, BoundedTensor
-from auto_LiRPA.perturbations import *
+from auto_LiRPA.perturbations import *
+from auto_LiRPA.utils import Flatten
from testcase import TestCase
-class Flatten(nn.Module):
- def __init__(self):
- super(Flatten, self).__init__()
-
- def forward(self, x):
- return x.view((x.shape[0], -1))
+def MadryCNN():
+ return nn.Sequential(
+ nn.Conv2d(1, 32, 5, stride=1, padding=2),
+ nn.ReLU(),
+ nn.MaxPool2d(2, stride=2),
+ nn.Conv2d(32, 64, 5, stride=1, padding=2),
+ nn.ReLU(),
+ nn.MaxPool2d(2, stride=2),
+ Flatten(),
+ nn.Linear(64*7*7,1024),
+ nn.ReLU(),
+ nn.Linear(1024, 10)
+ )
class Model(nn.Module):
- def __init__(self):
- super(Model, self).__init__()
- self.n_n_conv2d = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 32, 'padding': [0, 0], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True})
- self.n_n_average_pooling2d = nn.MaxPool2d(**{'kernel_size': [4, 4], 'ceil_mode': False, 'stride': [4, 4], 'padding': [0, 0]})
- self.n_n_flatten_Flatten = nn.Flatten(**{'start_dim': 1})
- self.n_n_dense = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 10, 'padding': [0, 0], 'kernel_size': (1, 1), 'stride': [1, 1], 'in_channels': 1152, 'bias': True})
- self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1})
-
- def forward(self, *inputs):
- t_ImageInputLayer, = inputs
- t_conv2d = self.n_n_conv2d(t_ImageInputLayer)
- t_conv2d_relu = F.relu(t_conv2d)
- t_average_pooling2d = self.n_n_average_pooling2d(t_conv2d_relu)[:, :, :, :]
- t_flatten_Transpose = t_average_pooling2d.permute(*[0, 2, 3, 1])
- t_flatten_Flatten = self.n_n_flatten_Flatten(t_flatten_Transpose)
- t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Flatten, 2)
- t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Unsqueeze, 3)
- t_dense = self.n_n_dense(t_flatten_Unsqueeze)
- t_activation_Flatten = self.n_n_activation_Flatten(t_dense)
- return t_activation_Flatten
-
-class TestConv(TestCase):
+ def __init__(self, kernel_size=4, stride=4, padding=0, conv_padding=0):
+ super(Model, self).__init__()
+ self.n_n_conv2d = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 1, 'padding': [0, 0], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True})
+ self.n_n_maxpool = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]})
+ self.n_n_conv2d_2 = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 8, 'padding': [conv_padding, conv_padding], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True})
+ self.n_n_maxpool_2 = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]})
+ self.n_n_flatten_Flatten = nn.Flatten(**{'start_dim': 1})
+
+ self.n_n_dense = None
+
+ self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1})
+
+ def forward(self, *inputs):
+ t_ImageInputLayer, = inputs
+ t_conv2d = self.n_n_conv2d(t_ImageInputLayer)
+ t_conv2d_relu = F.relu(t_conv2d)
+ t_maxpool = self.n_n_maxpool(t_conv2d_relu)[:, :, :, :]
+ t_conv2d_max = self.n_n_conv2d_2(t_maxpool)
+ t_conv2d_max = F.relu(t_conv2d_max)
+ t_maxpool_2 = self.n_n_maxpool_2(t_conv2d_max)
+ t_flatten_Transpose = t_maxpool_2.permute(*[0, 2, 3, 1])
+ t_flatten_Flatten = self.n_n_flatten_Flatten(t_flatten_Transpose)
+ t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Flatten, 2)
+ t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Unsqueeze, 3)
+
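+        # Build the final 1x1 conv lazily on the first forward pass so that its
+        # in_channels matches the flattened feature size, which varies with the
+        # kernel_size/stride/padding configuration under test.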
+ if self.n_n_dense is None:
+ self.n_n_dense = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 10, 'padding': [0, 0], 'kernel_size': (1, 1), 'stride': [1, 1], 'in_channels': t_flatten_Unsqueeze.shape[1], 'bias': True})
+ t_dense = self.n_n_dense(t_flatten_Unsqueeze)
+ t_activation_Flatten = self.n_n_activation_Flatten(t_dense)
+
+ return t_activation_Flatten
+
+class TestMaxPool(TestCase):
def __init__(self, methodName='runTest', generate=False):
- super().__init__(methodName,
+ super().__init__(methodName,
seed=1, ref_path=None,
generate=generate)
def test(self):
np.random.seed(123)
- models = [2, 3]
- paddings = [1, 2]
- strides = [1, 3]
-
- model_ori = Model()
- data = torch.load('data/maxpool_test_data')
- model_ori.load_state_dict(data['model'])
N = 2
- n_classes = 10
- image = data['input']
- # image = torch.rand([N,1,28,28])
- # image = image.to(torch.float32) / 255.0
- model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "matrix"})
- eps = 0.3
- norm = np.inf
- ptb = PerturbationLpNorm(norm=norm, eps=eps)
- image = BoundedTensor(image, ptb)
- pred = model(image)
- lb, ub = model.compute_bounds()
+ for kernel_size in [3,4]:
+ for padding in [0]:
+ for conv_padding in [0,1]:
+ print(kernel_size, padding, kernel_size, conv_padding)
+
+ model_ori = Model(kernel_size=kernel_size, padding=padding, stride=kernel_size, conv_padding=conv_padding)
+ if not self.generate:
+ data = torch.load('data/maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding))
+ image = data['input']
+ model_ori(image)
+ model_ori.load_state_dict(data['model'])
+ else:
+ image = torch.rand([N, 1, 28, 28])
+ model_ori(image)
+
+
+ if self.generate:
+ conv_mode = "matrix"
+ else:
+ conv_mode = "patches"
+
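+                    # Reference data are generated once in the exact "matrix" mode;
+                    # regular test runs then check the "patches" mode against them.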
+ model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": conv_mode})
+ eps = 0.3
+ norm = np.inf
+ ptb = PerturbationLpNorm(norm=norm, eps=eps)
+ image = BoundedTensor(image, ptb)
+ pred = model(image)
+
+ lb, ub = model.compute_bounds()
- lb_ref = data['lb']
- ub_ref = data['ub']
- assert torch.allclose(lb, lb_ref, 1e-4)
- assert torch.allclose(ub, ub_ref, 1e-4)
+ if self.generate:
+ torch.save(
+ {'model': model_ori.state_dict(),
+ 'input': image,
+ 'lb': lb,
+ 'ub': ub}, 'data/maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding)
+ )
- # lb, ub = model.compute_bounds(x=(image,), method="CROWN-Optimized")
+ if not self.generate:
+ lb_ref = data['lb']
+ ub_ref = data['ub']
- # torch.save({'input': image, 'model': model_ori.state_dict(), 'lb': lb, 'ub': ub}, 'data/maxpool_test_data')
+ assert torch.allclose(lb, lb_ref, 1e-4)
+ assert torch.allclose(ub, ub_ref, 1e-4)
if __name__ == '__main__':
- testcase = TestConv()
+ testcase = TestMaxPool(generate=True)
testcase.test()
diff --git a/tests/test_relative_ibp.py b/tests/test_relative_ibp.py
deleted file mode 100644
index 9eb99ff..0000000
--- a/tests/test_relative_ibp.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Test IBP with relative bounds"""
-import torch
-import os
-import torch.nn as nn
-import torch.nn.functional as F
-from auto_LiRPA import BoundedModule, BoundedTensor
-from auto_LiRPA.perturbations import *
-from auto_LiRPA.bound_ops import *
-from testcase import TestCase
-import sys
-sys.path.append('../examples/vision')
-from models import *
-
-class TestRelativeIBP(TestCase):
- def __init__(self, methodName='runTest', generate=False):
- super().__init__(methodName)
-
- def test(self):
- dummy_input = torch.randn(1, 3, 32, 32)
-
- model_ori = cnn_6layer(in_ch=3, in_dim=32)
- model = BoundedModule(model_ori, dummy_input, bound_opts={ 'ibp_relative': True })
-
- model_ori_ref = cnn_6layer(in_ch=3, in_dim=32)
- model_ori_ref.load_state_dict(model_ori.state_dict())
- model_ref = BoundedModule(model_ori_ref, dummy_input, bound_opts={ 'ibp_relative': False })
-
- eps = 1e-1
- data = torch.randn(8, 3, 32, 32)
- data_lb, data_ub = data - eps, data + eps
- ptb = PerturbationLpNorm(norm=np.inf, eps=eps, x_L=data_lb, x_U=data_ub, relative=True)
- x = (BoundedTensor(data, ptb),)
-
- fv = model(x)
- fv_ref = model_ref(x)
- lb, ub = model.compute_bounds(method='IBP')
- lb_ref, ub_ref = model_ref.compute_bounds(method='IBP')
-
- self.assertEqual(lb, lb_ref)
- self.assertEqual(ub, ub_ref)
diff --git a/tests/test_simple_verification.py b/tests/test_simple_verification.py
new file mode 100644
index 0000000..f648611
--- /dev/null
+++ b/tests/test_simple_verification.py
@@ -0,0 +1,55 @@
+"""Test optimized bounds in simple_verification."""
+import torch
+import torch.nn as nn
+import torchvision
+from auto_LiRPA import BoundedModule, BoundedTensor
+from auto_LiRPA.perturbations import PerturbationLpNorm
+from auto_LiRPA.utils import Flatten
+from testcase import TestCase
+
+# This simple model comes from https://github.com/locuslab/convex_adversarial
+def mnist_model():
+ model = nn.Sequential(
+ nn.Conv2d(1, 16, 4, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(16, 32, 4, stride=2, padding=1),
+ nn.ReLU(),
+ Flatten(),
+ nn.Linear(32*7*7,100),
+ nn.ReLU(),
+ nn.Linear(100, 10)
+ )
+ return model
+
+class TestSimpleVerification(TestCase):
+ def __init__(self, methodName='runTest'):
+ super().__init__(methodName)
+
+ def test(self):
+ model = mnist_model()
+ checkpoint = torch.load(
+ '../examples/vision/pretrain/mnist_a_adv.pth',
+ map_location=torch.device('cpu'))
+ model.load_state_dict(checkpoint)
+
+ test_data = torchvision.datasets.MNIST(
+ './data', train=False, download=True, transform=torchvision.transforms.ToTensor())
+ N = 2
+ image = test_data.data[:N].view(N,1,28,28)
+ image = image.to(torch.float32) / 255.0
+ if torch.cuda.is_available():
+ image = image.cuda()
+ model = model.cuda()
+
+ lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)
+ ptb = PerturbationLpNorm(0.3)
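+    # Assuming eps is the first positional parameter of PerturbationLpNorm in this
+    # version, this is an L-infinity perturbation (the default norm) with eps=0.3.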
+ image = BoundedTensor(image, ptb)
+
+ method = 'CROWN-Optimized (alpha-CROWN)'
+ lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})
+ _, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0])
+ self.assertTensorEqual(ub[0][7], torch.tensor(12.5080))
+
+if __name__ == '__main__':
+ testcase = TestSimpleVerification()
+ testcase.test()
diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py
index 4ac7b25..7a5019a 100644
--- a/tests/test_vision_models.py
+++ b/tests/test_vision_models.py
@@ -1,4 +1,5 @@
-import random
+import torch
+import torch.nn as nn
import torch.nn.functional as F
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import *
@@ -23,26 +24,28 @@ def forward(self, x):
return x
-class TestVisionModels(TestCase):
+class TestVisionModels(TestCase):
def __init__(self, methodName='runTest', generate=False):
- super().__init__(methodName, seed=1234, ref_path='data/vision_test_data')
+ super().__init__(methodName, seed=1234,
+ ref_path='data/vision_test_data', generate=generate)
self.result = {}
def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name):
lb, ub = model(method_opt="compute_bounds", x=(x,), IBP=IBP, method=method)
self.result[lb_name] = lb
self.result[ub_name] = ub
- assert torch.allclose(lb, self.reference[lb_name], 1e-4), (lb - self.reference[lb_name]).abs().sum()
- assert torch.allclose(ub, self.reference[ub_name], 1e-4), (ub - self.reference[ub_name]).abs().sum()
- assert ((lb - self.reference[lb_name]).pow(2).sum() < 1e-9), (lb - self.reference[lb_name]).pow(2).sum()
- assert ((ub - self.reference[ub_name]).pow(2).sum() < 1e-9), (ub - self.reference[ub_name]).pow(2).sum()
# test gradient backward propagation
loss = (ub - lb).abs().sum()
loss.backward()
grad = x.grad
- self.result[lb_name[:-2] + 'grad'] = grad
- assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6)
- assert (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum() < 1e-9
+ self.result[lb_name[:-2] + 'grad'] = grad.clone()
+ if not self.generate:
+ assert torch.allclose(lb, self.reference[lb_name], 1e-4), (lb - self.reference[lb_name]).abs().sum()
+ assert torch.allclose(ub, self.reference[ub_name], 1e-4), (ub - self.reference[ub_name]).abs().sum()
+ assert ((lb - self.reference[lb_name]).pow(2).sum() < 1e-9), (lb - self.reference[lb_name]).pow(2).sum()
+ assert ((ub - self.reference[ub_name]).pow(2).sum() < 1e-9), (ub - self.reference[ub_name]).pow(2).sum()
+ assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6)
+ assert (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum() < 1e-9
def test_bounds(self):
np.random.seed(123) # FIXME inconsistent seeds
@@ -83,6 +86,13 @@ def test_bounds(self):
ub_name='l_2_CROWN_ub') # CROWN
if self.generate:
- for item in self.result:
- self.reference = self.result[item]
+ self.result['data'] = self.reference['data']
+ self.result['model'] = self.reference['model']
self.save()
+
+
+if __name__ =="__main__":
+ # t = TestVisionModels(generate=True)
+ t = TestVisionModels()
+ t.setUp()
+ t.test_bounds()
\ No newline at end of file
diff --git a/tests/test_weight_perturbation.py b/tests/test_weight_perturbation.py
index 073f7b4..3cc7ba0 100644
--- a/tests/test_weight_perturbation.py
+++ b/tests/test_weight_perturbation.py
@@ -1,28 +1,26 @@
import copy
-import random
-import argparse
-import torch.nn.functional as F
import subprocess
import numpy as np
from testcase import TestCase
import sys
sys.path.append('../examples/vision')
import models
-from auto_LiRPA import BoundedModule, BoundedParameter
+from auto_LiRPA import BoundedModule
from auto_LiRPA.perturbations import *
-class TestWeightPerturbation(TestCase):
+class TestWeightPerturbation(TestCase):
def __init__(self, methodName='runTest', generate=False):
super().__init__(methodName, seed=1234, ref_path='data/weight_perturbation_test_data')
self.result = {}
def test_training(self):
+ # python weight_perturbation_training.py --device cpu --scheduler_opts start=1,length=100 --num_epochs 1 --truncate_data 5
ret = subprocess.run(
- ['python', 'weight_perturbation_training.py',
+ ['python', 'weight_perturbation_training.py',
'--device', 'cpu',
'--scheduler_opts', 'start=1,length=100',
- '--num_epochs', '1',
+ '--num_epochs', '1',
'--truncate_data', '5'],
cwd='../examples/vision', capture_output=True)
self.assertEqual(ret.returncode, 0, ret.stderr)
@@ -56,11 +54,12 @@ def test_perturbation(self):
self.result['model'] = model_ori.state_dict()
self.result['data'] = torch.randn(8, 1, 28, 28)
model_ori.load_state_dict(self.result['model'])
- state_dict = copy.deepcopy(model_ori.state_dict())
+ state_dict = copy.deepcopy(model_ori.state_dict())
dummy_input = self.result['data'].requires_grad_()
inputs = (dummy_input,)
- model = BoundedModule(model_ori, inputs)
+ model = BoundedModule(model_ori, inputs, bound_opts={
+ 'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})
forward_ret = model(dummy_input)
model_ori.eval()
@@ -69,7 +68,8 @@ def test_perturbation(self):
def verify_model(pert_weight=True, pert_bias=True, norm=np.inf, lb_name='', ub_name=''):
model_ori_ = models.Models['mlp_3layer_weight_perturb'](pert_weight=pert_weight, pert_bias=pert_bias, norm=norm).eval()
model_ori_.load_state_dict(state_dict)
- model_ = BoundedModule(model_ori_, inputs)
+ model_ = BoundedModule(model_ori_, inputs, bound_opts={
+ 'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})
model_.ptb = model_ori.ptb
self.verify_bounds(model_, dummy_input, IBP=True, method='backward', forward_ret=forward_ret,