diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml deleted file mode 100644 index 9af86dc..0000000 --- a/.github/workflows/CI.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: auto_LiRPA CI -on: [push] -jobs: - Tests: - runs-on: ubuntu-latest - steps: - - name: Create swap - run: | - sudo fallocate -l 16G /swapfile - sudo chmod 600 /swapfile - sudo mkswap /swapfile - sudo swapon /swapfile - free -h - - name: Setup Python - uses: actions/setup-python@v2.2.2 - with: - python-version: 3.8 - architecture: x64 - - name: Check out repository code - uses: actions/checkout@v2 - - name: Install auto_LiRPA - run: python setup.py install - - name: Install dependencies for examples - run: | - cd examples - pip install -r requirements.txt - cd .. - - name: Run tests - run: | - cd tests - python utils/download_models.py - pytest diff --git a/.gitignore b/.gitignore index 4197b11..3ace7db 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,16 @@ __pycache__ *.egg-info dist *.swp +*.swo *.log -.trace_graph \ No newline at end of file +.trace_graph +Verified_ret*.npy +Verified-acc*.npy +vnn-comp_*.npz +*.tar.gz +verifier_log_* +*.pth +*.pt +.idea +*.so +release diff --git a/.travis.yml b/.travis.yml index c9418c8..8c80a27 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,10 @@ python: - "3.8" install: - pip install --editable . - - cd examples + - cd examples - pip install -r requirements.txt + - pip install git+https://github.com/KaidiXu/onnx2pytorch.git + - pip install onnxruntime - cd .. - sudo fallocate -l 16G /swapfile - sudo chmod 600 /swapfile @@ -14,4 +16,5 @@ install: script: - cd tests - python utils/download_models.py - - pytest + - pytest + - cd .. \ No newline at end of file diff --git a/README.md b/README.md index 4a4121d..42c2200 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,17 @@ ## What's New? -- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. -- Support for [custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html). (01/02/2022) +- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2022](https://sites.google.com/view/vnn2022). Our library supports the large CIFAR100, TinyImageNet and ImageNet models in VNN-COMP 2022. (09/2022) +- Implementation of **general cutting planes** ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)), support of more activation functions and improved performance and scalability. (09/2022) +- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. (09/2021) - [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. 
See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021) - Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021) -- A memory efficient GPU implementation of backward (CROWN) bounds for +- A memory efficient GPU implementation of backward (CROWN) bounds for convolutional layers. (10/31/2020) - Certified defense models for downscaled ImageNet, TinyImageNet, CIFAR-10, LSTM/Transformer. (08/20/2020) - Adding support to **complex vision models** including DenseNet, ResNeXt and WideResNet. (06/30/2020) -- **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds -(e.g. CROWN-IBP) to the same asympototic complexity of IBP, making LiRPA based certified +- **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds +(e.g. CROWN-IBP) to the same asymptotic complexity of IBP, making LiRPA based certified defense scalable to large datasets (e.g., TinyImageNet, downscaled ImageNet). (06/30/2020) - **Multi-GPU** support to scale LiRPA based training to large models and datasets. (06/30/2020) - Initial release. (02/28/2020) @@ -46,6 +47,7 @@ Our library supports the following algorithms: * Backward mode LiRPA bound propagation ([CROWN](https://arxiv.org/pdf/1811.00866.pdf)/[DeepPoly](https://files.sri.inf.ethz.ch/website/papers/DeepPoly.pdf)) * Backward mode LiRPA bound propagation with optimized bounds ([α-CROWN](https://arxiv.org/pdf/2011.13824.pdf)) * Backward mode LiRPA bound propagation with split constraints ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) +* Generalized backward mode LiRPA bound propagation with general cutting plane constraints ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)) * Forward mode LiRPA bound propagation ([Xu et al., 2020](https://arxiv.org/pdf/2002.12920)) * Forward mode LiRPA bound propagation with optimized bounds (similar to [α-CROWN](https://arxiv.org/pdf/2011.13824.pdf)) * Interval bound propagation ([IBP](https://arxiv.org/pdf/1810.12715.pdf)) @@ -89,9 +91,11 @@ user-defined ranges. We get guaranteed output ranges (bounds): ## Installation -Python 3.7+ is required. Pytorch 1.8 (LTS) is recommended, although a newer -version might also work. It is highly recommended to have a pre-installed PyTorch -that matches your system and our version requirement. See [PyTorch Get Started](https://pytorch.org/get-started). +Python 3.7+ and PyTorch 1.8+ are required. +PyTorch 1.11 is recommended, although other recent versions might also work. +It is highly recommended to have a pre-installed PyTorch +that matches your system and our version requirement. +See [PyTorch Get Started](https://pytorch.org/get-started). Then you can install `auto_LiRPA` via: ```bash @@ -102,6 +106,11 @@ python setup.py install If you intend to modify this library, use `python setup.py develop` instead. +Optionally, you may build and install native CUDA modules (CUDA toolkit required): +```bash +python auto_LiRPA/cuda_utils.py install +``` + ## Quick Start First define your computation as a `nn.Module` and wrap it using @@ -127,7 +136,7 @@ my_input = BoundedTensor(my_input, ptb) # Regular forward propagation using BoundedTensor works as usual. prediction = model(my_input) # Compute LiRPA bounds using the backward mode bound propagation (CROWN). -lb, ub = model.compute_bounds(x=(my_input,), method="CROWN") +lb, ub = model.compute_bounds(x=(my_input,), method="backward") ``` Checkout @@ -142,7 +151,7 @@ obtaining gradients through autodiff.
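For example, a minimal sketch of differentiating through the bounds (assuming `model` and `my_input` are defined as in the Quick Start above; the objective shown is purely illustrative):

```python
# lb and ub returned by compute_bounds() are ordinary torch tensors that
# remain in the autodiff graph, so they can appear in a loss function.
lb, ub = model.compute_bounds(x=(my_input,), method="backward")
loss = -lb.sum()  # hypothetical objective: push certified lower bounds up
loss.backward()   # gradients w.r.t. model parameters via autodiff
```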
Bounds are efficiently computed on GPUs. ## More Working Examples -We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`: +We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`: * [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/src/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks) * [Basic **Certified Adversarial Defense** Training](doc/src/examples.md#basic-certified-adversarial-defense-training) @@ -151,6 +160,10 @@ We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA` * [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm) * [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense) +`auto_LiRPA` has also been used in the following works: +* [**α,β-CROWN for complete neural network verification**](https://github.com/huanzhang12/alpha-beta-CROWN) +* [**Fast certified robust training**](https://github.com/shizhouxing/Fast-Certified-Robust-Training) + ## Full Documentations For more documentations, please refer to: @@ -166,20 +179,27 @@ Please kindly cite our papers if you use the `auto_LiRPA` library. Full [BibTeX The general LiRPA based bound propagation algorithm was originally proposed in our paper: -* [Automatic Perturbation Analysis for Scalable Certified Robustness and Beyond](https://arxiv.org/pdf/2002.12920). -NeurIPS 2020 +* [Automatic Perturbation Analysis for Scalable Certified Robustness and Beyond](https://arxiv.org/pdf/2002.12920). +NeurIPS 2020 Kaidi Xu\*, Zhouxing Shi\*, Huan Zhang\*, Yihan Wang, Kai-Wei Chang, Minlie Huang, Bhavya Kailkhura, Xue Lin, Cho-Jui Hsieh (\* Equal contribution) -The `auto_LiRPA` library is further extended to allow optimized bound (α-CROWN) and split constraints (β-CROWN): +The `auto_LiRPA` library is further extended to allow optimized bound (α-CROWN), split constraints (β-CROWN) and general constraints (GCP-CROWN): -* [Fast and Complete: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers](https://arxiv.org/pdf/2011.13824.pdf). -ICLR 2021 -Kaidi Xu\*, Huan Zhang\*, Shiqi Wang, Yihan Wang, Suman Jana, Xue Lin and Cho-Jui Hsieh (\* Equal contribution) +* [Fast and Complete: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers](https://arxiv.org/pdf/2011.13824.pdf). +ICLR 2021. +Kaidi Xu\*, Huan Zhang\*, Shiqi Wang, Yihan Wang, Suman Jana, Xue Lin and Cho-Jui Hsieh (\* Equal contribution). -* [Beta-CROWN: Efficient Bound Propagation with Per-neuron Split Constraints for Complete and Incomplete Neural Network Verification](https://arxiv.org/pdf/2103.06624.pdf). -NeurIPS 2021 -Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution) +* [Beta-CROWN: Efficient Bound Propagation with Per-neuron Split Constraints for Complete and Incomplete Neural Network Verification](https://arxiv.org/pdf/2103.06624.pdf). +NeurIPS 2021. +Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution). +* [GCP-CROWN: General Cutting Planes for Bound-Propagation-Based Neural Network Verification](https://arxiv.org/abs/2208.05740).
+Huan Zhang\*, Shiqi Wang\*, Kaidi Xu\*, Linyi Li, Bo Li, Suman Jana, Cho-Jui Hsieh and Zico Kolter (\* Equal contribution). + +Certified robust training using `auto_LiRPA` is improved to allow much shorter warmup and faster training: +* [Fast Certified Robust Training with Short Warmup](https://arxiv.org/pdf/2103.17268.pdf). +NeurIPS 2021. +Zhouxing Shi\*, Yihan Wang\*, Huan Zhang, Jinfeng Yi and Cho-Jui Hsieh (\* Equal contribution). ## Developers and Copyright @@ -187,12 +207,20 @@ Shiqi Wang\*, Huan Zhang\*, Kaidi Xu\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Z |:--:|:--:| :--:| :--:| :--:| | | | | | | -* Kaidi Xu (xu.kaid@northeastern.edu): main developer -* Zhouxing Shi (zshi@cs.ucla.edu): main developer -* Huan Zhang (huan@huan-zhang.com): team lead -* Yihan Wang (yihanwang@ucla.edu) -* Shiqi Wang (sw3215@columbia.edu): contact for beta-CROWN +Team lead: +* Huan Zhang (huan@huan-zhang.com), CMU + +Main developers: +* Zhouxing Shi (zshi@cs.ucla.edu), UCLA +* Kaidi Xu (kx46@drexel.edu), Drexel University + +Contributors: +* Yihan Wang (yihanwang@ucla.edu), UCLA +* Shiqi Wang (sw3215@columbia.edu), Columbia University +* Linyi Li (linyi2@illinois.edu), UIUC +* Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU +* Zhuolin Yang (zhuolin5@illinois.edu), UIUC -We thank [commits](https://github.com/KaidiXu/auto_LiRPA/commits) and [pull requests](https://github.com/KaidiXu/auto_LiRPA/pulls) from community contributors. +We thank the [commits](https://github.com/KaidiXu/auto_LiRPA/commits) and [pull requests](https://github.com/KaidiXu/auto_LiRPA/pulls) from community contributors. Our library is released under the BSD 3-Clause license. diff --git a/auto_LiRPA/__init__.py b/auto_LiRPA/__init__.py index f0b5922..3fbdb43 100644 --- a/auto_LiRPA/__init__.py +++ b/auto_LiRPA/__init__.py @@ -1,7 +1,8 @@ -from .bound_general import BoundedModule, BoundDataParallel +from .bound_general import BoundedModule +from .bound_multi_gpu import BoundDataParallel from .bounded_tensor import BoundedTensor, BoundedParameter from .perturbations import PerturbationLpNorm, PerturbationSynonym from .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput from .bound_op_map import register_custom_op, unregister_custom_op -__version__ = '0.2' \ No newline at end of file +__version__ = '0.3' diff --git a/auto_LiRPA/adam_element_lr.py b/auto_LiRPA/adam_element_lr.py deleted file mode 100644 index 8324110..0000000 --- a/auto_LiRPA/adam_element_lr.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -import math -from torch.optim.optimizer import Optimizer -from torch import Tensor -from typing import List, Optional - - -def adam(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[int], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - lr_scale: Optional[Tensor], - batch_dim: Optional[int]): - r"""Functional API that performs Adam algorithm computation. - See :class:`~torch.optim.Adam` for details.
- """ - - for i, param in enumerate(params): - - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = state_steps[i] - - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - if weight_decay != 0: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) - else: - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - if lr_scale is not None: - # Per batch element learning rate scaler. - # We must know which dimension is corresponding to batch and broadcast accordingly. - total_dim = exp_avg.ndim - new_shape = (1, ) * batch_dim + (lr_scale.size(0), ) + (1, ) * (total_dim - 1 - batch_dim) - scaler = lr_scale.view(*new_shape) - param.addcdiv_(scaler * exp_avg, denom, value=-step_size) - else: - param.addcdiv_(exp_avg, denom, value=-step_size) - if lr_scale is not None: - pass - # print('lr scaler', lr_scale) - - -class AdamElementLR(Optimizer): - r"""Implements Adam algorithm, with the capability of setting different lr - per batch element. - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - The implementation of the L2 penalty follows changes proposed in - `Decoupled Weight Decay Regularization`_. - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. 
_On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) - super(AdamElementLR, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Adam, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, lr_scale=None, closure=None): - """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for i, group in enumerate(self.param_groups): - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - max_exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is not None: - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - grads.append(p.grad) - - state = self.state[p] - # Lazy state initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - if group['amsgrad']: - # Maintains max of all exp. moving avg. of sq. grad. 
values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - - if group['amsgrad']: - max_exp_avg_sqs.append(state['max_exp_avg_sq']) - - # update the steps for each param group update - state['step'] += 1 - # record the step after step update - state_steps.append(state['step']) - - adam(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=group['amsgrad'], - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - lr_scale=lr_scale[i] if lr_scale is not None else None, - batch_dim=group['batch_dim'], - ) - return loss diff --git a/auto_LiRPA/backward_bound.py b/auto_LiRPA/backward_bound.py new file mode 100644 index 0000000..411ae9e --- /dev/null +++ b/auto_LiRPA/backward_bound.py @@ -0,0 +1,766 @@ +import torch +from torch import Tensor +from collections import deque, defaultdict +from tqdm import tqdm +from .patches import Patches +from .utils import * +from .bound_ops import * +import warnings + + +def batched_backward( + self, node, C, unstable_idx, batch_size, bound_lower=True, + bound_upper=True): + crown_batch_size = self.bound_opts['crown_batch_size'] + unstable_size = get_unstable_size(unstable_idx) + print(f'Batched CROWN: unstable size {unstable_size}') + num_batches = (unstable_size + crown_batch_size - 1) // crown_batch_size + output_shape = node.output_shape[1:] + dim = int(prod(output_shape)) + ret = [] + for i in tqdm(range(num_batches)): + if isinstance(unstable_idx, tuple): + unstable_idx_batch = tuple( + u[i*crown_batch_size:(i+1)*crown_batch_size] + for u in unstable_idx + ) + unstable_size_batch = len(unstable_idx_batch[0]) + else: + unstable_idx_batch = unstable_idx[i*crown_batch_size:(i+1)*crown_batch_size] + unstable_size_batch = len(unstable_idx_batch) + if node.patches_start and node.mode == "patches": + assert C in ['Patches', None] + C_batch = Patches(shape=[ + unstable_size_batch, batch_size, *node.output_shape[1:-2], 1, 1], + identity=1, unstable_idx=unstable_idx_batch, + output_shape=[batch_size, *node.output_shape[1:]]) + elif isinstance(node, BoundLinear) or isinstance(node, BoundMatMul): + assert C in ['OneHot', None] + C_batch = OneHotC( + [batch_size, unstable_size_batch, *node.output_shape[1:]], + self.device, unstable_idx_batch, None) + else: + assert C in ['eye', None] + C_batch = torch.zeros([1, unstable_size_batch, dim], device=self.device) + C_batch[0, torch.arange(unstable_size_batch), unstable_idx_batch] = 1.0 + C_batch = C_batch.expand(batch_size, -1, -1).view( + batch_size, unstable_size_batch, *output_shape) + ret.append(self.backward_general( + C=C_batch, node=node, + bound_lower=bound_lower, bound_upper=bound_upper, + average_A=False, need_A_only=False, unstable_idx=unstable_idx_batch, + verbose=False)) + if bound_lower: + lb = torch.cat([item[0].view(batch_size, -1) for item in ret], dim=1) + else: + lb = None + if bound_upper: + ub = torch.cat([item[1].view(batch_size, -1) for item in ret], dim=1) + else: + ub = None + return lb, ub + + +def backward_general( + self, C, node=None, bound_lower=True, bound_upper=True, average_A=False, + need_A_only=False, unstable_idx=None, unstable_size=0, update_mask=None, verbose=True): + if verbose: + logger.debug(f'Bound backward from {node.__class__.__name__}({node.name})') + if isinstance(C, str): + logger.debug(f' C: {C}') + elif C is not None: + logger.debug(f' C: shape {C.shape}, type 
{type(C)}') + _print_time = False + + if isinstance(C, str): + # If C is a str, use batched CROWN. If batched CROWN is not intended to + # be enabled, C must be an explicitly provided non-str object for this function. + if need_A_only or self.return_A or average_A: + raise ValueError( + 'Batched CROWN is not compatible with ' + f'need_A_only={need_A_only}, return_A={self.return_A}, ' + f'average_A={average_A}') + node.lower, node.upper = self.batched_backward( + node, C, unstable_idx, + batch_size=self.root[0].value.shape[0], + bound_lower=bound_lower, bound_upper=bound_upper, + ) + return node.lower, node.upper + + for l in self._modules.values(): + l.lA = l.uA = None + l.bounded = True + + degree_out = get_degrees(node, self.backward_from) + all_nodes_before = list(degree_out.keys()) + + C, batch_size, output_dim, output_shape = preprocess_C(C, node) + + node.lA = C if bound_lower else None + node.uA = C if bound_upper else None + lb = ub = torch.tensor(0., device=self.device) + + + # Save intermediate layer A matrices when required. + A_record = {} + + queue = deque([node]) + while len(queue) > 0: + l = queue.popleft() # backward from l + l.bounded = True + + if l.name in self.root_name: continue + + # If all the successors of a node are done, we can process that node in the + # next iteration. + for l_pre in l.inputs: + degree_out[l_pre.name] -= 1 + if degree_out[l_pre.name] == 0: + queue.append(l_pre) + + # Initially, l.lA or l.uA will be set to C for this node. + if l.lA is not None or l.uA is not None: + if verbose: + logger.debug(f' Bound backward to {l}' + f' (lA shape {l.lA.shape if l.lA is not None else None},' + f' uA shape {l.uA.shape if l.uA is not None else None},' + f' out shape {l.output_shape})') + + if _print_time: + start_time = time.time() + + if not l.perturbed: + if not hasattr(l, 'forward_value'): + self.get_forward_value(l) + lb, ub = add_constant_node(lb, ub, l) + continue + + if l.zero_uA_mtx and l.zero_lA_mtx: + # A matrices are all zero, no need to propagate. + continue + + if isinstance(l, BoundRelu): + # TODO: unify this interface. + A, lower_b, upper_b = l.bound_backward( + l.lA, l.uA, *l.inputs, start_node=node, unstable_idx=unstable_idx, + beta_for_intermediate_layers=self.intermediate_constr is not None) + elif isinstance(l, BoundOptimizableActivation): + # For other optimizable activation functions (TODO: unify with ReLU). + if node.name != self.final_node_name: + start_shape = node.output_shape[1:] + else: + start_shape = C.shape[0] + A, lower_b, upper_b = l.bound_backward( + l.lA, l.uA, *l.inputs, start_shape=start_shape, start_node=node) + else: + A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs) + # After propagation through this node, we delete its lA, uA variables. + if not self.return_A and node.name != self.final_name: + del l.lA, l.uA + if _print_time: + time_elapsed = time.time() - start_time + if time_elapsed > 1e-3: + print(l, time_elapsed) + if lb.ndim > 0 and type(lower_b) == Tensor and self.conv_mode == 'patches': + lb, ub, lower_b, upper_b = check_patch_biases(lb, ub, lower_b, upper_b) + lb = lb + lower_b + ub = ub + upper_b + if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict: + # FIXME remove [0][0] and [0][1]?
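+ # Note: A_record stores, for each preceding node, the lA/uA coefficients + # transposed back to (batch, spec, ...) layout along with the accumulated + # bias terms, so callers that request A matrices can consume them directly.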
+ if len(self.needed_A_dict[node.name]) == 0 or l.name in self.needed_A_dict[node.name]: + A_record.update({l.name: { + "lA": A[0][0].transpose(0, 1).detach() if A[0][0] is not None else None, + "uA": A[0][1].transpose(0, 1).detach() if A[0][1] is not None else None, + # When not used, lb or ub is tensor(0). + "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None, + "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None, + "unstable_idx": unstable_idx + }}) + # FIXME: solve conflict with the following case + self.A_dict.update({node.name: A_record}) + if need_A_only and set(self.needed_A_dict[node.name]) == set(A_record.keys()): + # We have collected all A matrices we need. We can return now! + self.A_dict.update({node.name: A_record}) + # Do not concretize to save time. We just need the A matrices. + # return A matrix as a dict: {node.name: [A_lower, A_upper]} + return None, None, self.A_dict + + + for i, l_pre in enumerate(l.inputs): + add_bound(l, l_pre, lA=A[i][0], uA=A[i][1]) + + if lb.ndim >= 2: + lb = lb.transpose(0, 1) + if ub.ndim >= 2: + ub = ub.transpose(0, 1) + + if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict: + save_A_record( + node, A_record, self.A_dict, self.root, self.needed_A_dict[node.name], + lb=lb, ub=ub, unstable_idx=unstable_idx) + + # TODO merge into `concretize` + if self.cut_used and getattr(self, 'cut_module', None) is not None and self.cut_module.x_coeffs is not None: + # propagate input neuron in cut constraints + self.root[0].lA, self.root[0].uA = self.cut_module.input_cut( + node, self.root[0].lA, self.root[0].uA, self.root[0].lower.size()[1:], unstable_idx, + batch_mask=update_mask) + + lb, ub = concretize( + lb, ub, node, self.root, batch_size, output_dim, + bound_lower, bound_upper, average_A=average_A) + + # TODO merge into `concretize` + if self.cut_used and getattr(self, "cut_module", None) is not None and self.cut_module.cut_bias is not None: + # propagate cut bias in cut constraints + lb, ub = self.cut_module.bias_cut(node, lb, ub, unstable_idx, batch_mask=update_mask) + if lb is not None and ub is not None and ((lb-ub)>0).sum().item() > 0: + # make sure there is no bug for cut constraints propagation + print(f"Warning: lb is larger than ub with diff: {(lb-ub)[(lb-ub)>0].max().item()}") + + lb = lb.view(batch_size, *output_shape) if bound_lower else None + ub = ub.view(batch_size, *output_shape) if bound_upper else None + + if verbose: + logger.debug('') + + if self.return_A: + return lb, ub, self.A_dict + else: + return lb, ub + + +def get_unstable_size(unstable_idx): + if isinstance(unstable_idx, tuple): + return unstable_idx[0].numel() + else: + return unstable_idx.numel() + + +def check_optimized_variable_sparsity(self, node): + alpha_sparsity = None # unknown. 
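+ # A non-None entry in a ReLU's alpha_lookup_idx for this start node means + # its alpha variables were created sparsely, covering unstable neurons only.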
+ for relu in self.relus: + if hasattr(relu, 'alpha_lookup_idx') and node.name in relu.alpha_lookup_idx: + if relu.alpha_lookup_idx[node.name] is not None: + # This node was created with sparse alpha + alpha_sparsity = True + else: + alpha_sparsity = False + break + # print(f'node {node.name} alpha sparsity {alpha_sparsity}') + return alpha_sparsity + + +def get_sparse_C( + self, node, sparse_intermediate_bounds=True, ref_intermediate_lb=None, + ref_intermediate_ub=None): + sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False) + minimum_sparsity = self.bound_opts.get('minimum_sparsity', 0.9) + crown_batch_size = self.bound_opts.get('crown_batch_size', 1e9) + dim = int(prod(node.output_shape[1:])) + batch_size = self.batch_size + + reduced_dim = False # Only partial neurons (unstable neurons) are bounded. + unstable_idx = None + unstable_size = np.inf + newC = None + + alpha_is_sparse = self.check_optimized_variable_sparsity(node) + + # NOTE: batched CROWN is so far only supported for some of the cases below + + # FIXME: C matrix shape incorrect for BoundParams. + if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int( + os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0: + if sparse_intermediate_bounds: + # If we are doing bound refinement and reference bounds are given, we only refine unstable neurons. + # Also, if we are checking against LP solver we will refine all neurons and do not use this optimization. + # For each batch element, we find the unstable neurons. + unstable_idx, unstable_size = self.get_unstable_locations( + ref_intermediate_lb, ref_intermediate_ub) + if unstable_size == 0: + # Do nothing, no bounds will be computed. + reduced_dim = True + unstable_idx = [] + elif unstable_size > crown_batch_size: + # Create C in batched CROWN + newC = 'OneHot' + reduced_dim = True + elif (unstable_size <= minimum_sparsity * dim and unstable_size > 0 and alpha_is_sparse is None) or alpha_is_sparse: + # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity. + # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches. + newC = OneHotC( + [batch_size, unstable_size, *node.output_shape[1:]], + self.device, unstable_idx, None) + reduced_dim = True + else: + unstable_idx = None + del ref_intermediate_lb, ref_intermediate_ub + if not reduced_dim: + newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device) + elif node.patches_start and node.mode == "patches": + if sparse_intermediate_bounds: + unstable_idx, unstable_size = self.get_unstable_locations( + ref_intermediate_lb, ref_intermediate_ub, conv=True) + if unstable_size == 0: + # Do nothing, no bounds will be computed. + reduced_dim = True + unstable_idx = [] + elif unstable_size > crown_batch_size: + # Create C in batched CROWN + newC = 'Patches' + reduced_dim = True + # We sum over the channel direction, so need to multiply that. + elif (sparse_conv_intermediate_bounds and unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse: + # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity. + # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches. + # The shape of patches is [unstable_size, batch, C, H, W]. 
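+ # identity=1 together with unstable_idx builds an identity Patches object + # restricted to the unstable locations, so only those neurons receive a + # specification row.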
+ newC = Patches( + shape=[unstable_size, batch_size, *node.output_shape[1:-2], 1, 1], + identity=1, unstable_idx=unstable_idx, + output_shape=[batch_size, *node.output_shape[1:]]) + reduced_dim = True + else: + unstable_idx = None + del ref_intermediate_lb, ref_intermediate_ub + # Here we create an Identity Patches object + if not reduced_dim: + newC = Patches( + None, 1, 0, [node.output_shape[1], batch_size, *node.output_shape[2:], + *node.output_shape[1:-2], 1, 1], 1, + output_shape=[batch_size, *node.output_shape[1:]]) + elif isinstance(node, (BoundAdd, BoundSub)) and node.mode == "patches": + # FIXME: BoundAdd does not always have patches. Need to use a better way to determine patches mode. + # FIXME: We should not hardcode BoundAdd here! + if sparse_intermediate_bounds: + if crown_batch_size < 1e9: + warnings.warn('Batched CROWN is not supported in this case') + unstable_idx, unstable_size = self.get_unstable_locations( + ref_intermediate_lb, ref_intermediate_ub, conv=True) + if unstable_size == 0: + # Do nothing, no bounds will be computed. + reduced_dim = True + unstable_idx = [] + elif (sparse_conv_intermediate_bounds and unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse: + # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity. + num_channel = node.output_shape[-3] + # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1). + patches = ( + torch.eye(num_channel, device=self.device, + dtype=list(self.parameters())[0].dtype)).view( + num_channel, 1, 1, 1, num_channel, 1, 1) + # Expand to (out_c, 1, unstable_size, out_c, 1, 1). + patches = patches.expand(-1, 1, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) + patches = patches[unstable_idx[0], :, unstable_idx[1], unstable_idx[2]] + # Expand with the batch dimension. Final shape (unstable_size, batch_size, out_c, 1, 1). + patches = patches.expand(-1, batch_size, -1, -1, -1) + newC = Patches( + patches, 1, 0, patches.shape, unstable_idx=unstable_idx, + output_shape=[batch_size, *node.output_shape[1:]]) + reduced_dim = True + else: + unstable_idx = None + del ref_intermediate_lb, ref_intermediate_ub + if not reduced_dim: + num_channel = node.output_shape[-3] + # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1). + patches = ( + torch.eye(num_channel, device=self.device, + dtype=list(self.parameters())[0].dtype)).view( + num_channel, 1, 1, 1, num_channel, 1, 1) + # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1). + patches = patches.expand(-1, batch_size, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) + newC = Patches(patches, 1, 0, patches.shape, output_shape=[batch_size, *node.output_shape[1:]]) + else: + if sparse_intermediate_bounds: + unstable_idx, unstable_size = self.get_unstable_locations( + ref_intermediate_lb, ref_intermediate_ub) + if unstable_size == 0: + # Do nothing, no bounds will be computed.
+ reduced_dim = True + unstable_idx = [] + elif unstable_size > crown_batch_size: + # Create C in batched CROWN + newC = 'eye' + reduced_dim = True + elif (unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse: + newC = torch.zeros([1, unstable_size, dim], device=self.device) + # Fill the corresponding elements to 1.0 + newC[0, torch.arange(unstable_size), unstable_idx] = 1.0 + newC = newC.expand(batch_size, -1, -1).view(batch_size, unstable_size, *node.output_shape[1:]) + reduced_dim = True + else: + unstable_idx = None + del ref_intermediate_lb, ref_intermediate_ub + if not reduced_dim: + if dim > 1000: + warnings.warn(f"Creating an identity matrix with size {dim}x{dim} for node {node}. This may indicate poor performance for bound computation. If you see this message on a small network please submit a bug report.", stacklevel=2) + newC = torch.eye(dim, device=self.device, dtype=list(self.parameters())[0].dtype) \ + .unsqueeze(0).expand(batch_size, -1, -1) \ + .view(batch_size, dim, *node.output_shape[1:]) + + return newC, reduced_dim, unstable_idx, unstable_size + + +def restore_sparse_bounds( + self, node, unstable_idx, unstable_size, ref_intermediate_lb, ref_intermediate_ub, + new_lower=None, new_upper=None): + batch_size = self.batch_size + if unstable_size == 0: + # No unstable neurons. Skip the update. + node.lower = ref_intermediate_lb.detach().clone() + node.upper = ref_intermediate_ub.detach().clone() + else: + if new_lower is None: + new_lower = node.lower + if new_upper is None: + new_upper = node.upper + # If we only calculated unstable neurons, we need to scatter the results back based on reference bounds. + if isinstance(unstable_idx, tuple): + lower = ref_intermediate_lb.detach().clone() + upper = ref_intermediate_ub.detach().clone() + # Conv layer with patches, the unstable_idx is a 3-element tuple for 3 indices (C, H, W) of unstable neurons. + if len(unstable_idx) == 3: + lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_lower + upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_upper + elif len(unstable_idx) == 4: + lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_lower + upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_upper + else: + # Other layers. + lower = ref_intermediate_lb.detach().clone().view(batch_size, -1) + upper = ref_intermediate_ub.detach().clone().view(batch_size, -1) + lower[:, unstable_idx] = new_lower.view(batch_size, -1) + upper[:, unstable_idx] = new_upper.view(batch_size, -1) + node.lower = lower.view(batch_size, *node.output_shape[1:]) + node.upper = upper.view(batch_size, *node.output_shape[1:]) + + +def get_degrees(node_start, backward_from): + degrees = {} + queue = deque([node_start]) + node_start.bounded = False + while len(queue) > 0: + l = queue.popleft() + backward_from[l.name].append(node_start) + for l_pre in l.inputs: + degrees[l_pre.name] = degrees.get(l_pre.name, 0) + 1 + if l_pre.bounded: + l_pre.bounded = False + queue.append(l_pre) + return degrees + + +def preprocess_C(C, node): + if isinstance(C, Patches): + if C.unstable_idx is None: + # Patches have size (out_c, batch, out_h, out_w, c, h, w). + if len(C.shape) == 7: + out_c, batch_size, out_h, out_w = C.shape[:4] + output_dim = out_c * out_h * out_w + else: + out_dim, batch_size, out_c, out_h, out_w = C.shape[:5] + output_dim = out_dim * out_c * out_h * out_w + else: + # Patches have size (unstable_size, batch, c, h, w).
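+ # For sparse patches, the leading dimension directly gives the number of + # output specifications (one per unstable neuron).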
+ output_dim, batch_size = C.shape[:2] + else: + batch_size, output_dim = C.shape[:2] + + # The C matrix specified by the user has shape (batch, spec) but internally we have (spec, batch) format. + if not isinstance(C, (eyeC, Patches, OneHotC)): + C = C.transpose(0, 1) + elif isinstance(C, eyeC): + C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:])) + elif isinstance(C, OneHotC): + C = C._replace( + shape=(C.shape[1], C.shape[0], *C.shape[2:]), + index=C.index.transpose(0,-1), + coeffs=None if C.coeffs is None else C.coeffs.transpose(0,-1)) + + if isinstance(C, Patches) and C.unstable_idx is not None: + # Sparse patches; the output shape is (unstable_size, ). + output_shape = [C.shape[0]] + elif prod(node.output_shape[1:]) != output_dim and not isinstance(C, Patches): + # For the output node, the shape of the bound follows C + # instead of the original output shape + # + # TODO Maybe don't set node.lower and node.upper in this case? + # Currently some codes still depend on node.lower and node.upper + output_shape = [-1] + else: + # Generally, the shape of the bounds match the output shape of the node + output_shape = node.output_shape[1:] + + return C, batch_size, output_dim, output_shape + + +def concretize(lb, ub, node, root, batch_size, output_dim, bound_lower=True, bound_upper=True, average_A=False): + + for i in range(len(root)): + if root[i].lA is None and root[i].uA is None: continue + if average_A and isinstance(root[i], BoundParams): + lA = root[i].lA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].lA.shape) if bound_lower else None + uA = root[i].uA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].uA.shape) if bound_upper else None + else: + lA, uA = root[i].lA, root[i].uA + if not isinstance(root[i].lA, eyeC) and not isinstance(root[i].lA, Patches): + lA = root[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None + if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].uA, Patches): + uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None + if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: + if isinstance(root[i], BoundParams): + # add batch_size dim for weights node + lb = lb + root[i].perturbation.concretize( + root[i].center.unsqueeze(0), lA, + sign=-1, aux=root[i].aux) if bound_lower else None + ub = ub + root[i].perturbation.concretize( + root[i].center.unsqueeze(0), uA, + sign=+1, aux=root[i].aux) if bound_upper else None + else: + lb = lb + root[i].perturbation.concretize( + root[i].center, lA, sign=-1, aux=root[i].aux) if bound_lower else None + ub = ub + root[i].perturbation.concretize( + root[i].center, uA, sign=+1, aux=root[i].aux) if bound_upper else None + else: + fv = root[i].forward_value + if type(root[i]) == BoundInput: + # Input node with a batch dimension + batch_size_ = batch_size + else: + # Parameter node without a batch dimension + batch_size_ = 1 + + if bound_lower: + if isinstance(lA, eyeC): + lb = lb + fv.view(batch_size_, -1) + elif isinstance(lA, Patches): + lb = lb + lA.matmul(fv, input_shape=root[0].center.shape) + elif type(root[i]) == BoundInput: + lb = lb + lA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1) + else: + lb = lb + lA.matmul(fv.view(-1, 1)).squeeze(-1) + else: + lb = None + + if bound_upper: + if isinstance(uA, eyeC): + ub = ub + fv.view(batch_size_, -1) + elif isinstance(uA, Patches): + ub = ub + uA.matmul(fv, input_shape=root[0].center.shape) + elif type(root[i]) == BoundInput: + ub = ub + 
uA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1) + else: + ub = ub + uA.matmul(fv.view(-1, 1)).squeeze(-1) + else: + ub = None + + return lb, ub + + +def addA(A1, A2): + """ Add two A (each of them is either Tensor or Patches) """ + if type(A1) == type(A2): + return A1 + A2 + elif type(A1) == Patches: + return A1 + A2 + elif type(A2) == Patches: + return A2 + A1 + else: + raise NotImplementedError(f'Unsupported types for A1 ({type(A1)}) and A2 ({type(A2)})') + + +def add_bound(node, node_pre, lA, uA): + """Propagate lA and uA to a preceding node.""" + if lA is not None: + if node_pre.lA is None: + # First A added to this node. + node_pre.zero_lA_mtx = node.zero_backward_coeffs_l + node_pre.lA = lA + else: + node_pre.zero_lA_mtx = node_pre.zero_lA_mtx and node.zero_backward_coeffs_l + new_node_lA = addA(node_pre.lA, lA) + node_pre.lA = new_node_lA + if uA is not None: + if node_pre.uA is None: + # First A added to this node. + node_pre.zero_uA_mtx = node_pre.zero_backward_coeffs_u + node_pre.uA = uA + else: + node_pre.zero_uA_mtx = node_pre.zero_uA_mtx and node.zero_backward_coeffs_u + node_pre.uA = addA(node_pre.uA, uA) + + +def get_beta_watch_list(intermediate_constr, all_nodes_before): + beta_watch_list = defaultdict(dict) + if intermediate_constr is not None: + # Intermediate layer betas are handled in two cases. + # First, if the beta split is before this node, we don't need to do anything special; + # it will be done in BoundRelu. + # Second, if the beta split is after this node, we need to modify the A matrix + # during bound propagation to reflect beta after this layer. + for k in intermediate_constr: + if k not in all_nodes_before: + # The second case needs special care: we add all such splits in a watch list. + # However, after the first occurrence of a layer in the watch list, + # beta_watch_list will be deleted and the A matrix from split constraints + # has been added and will be propagated to later layers. + for kk, vv in intermediate_constr[k].items(): + beta_watch_list[kk][k] = vv + return beta_watch_list + + +def add_constant_node(lb, ub, node): + new_lb = node.get_bias(node.lA, node.forward_value) + new_ub = node.get_bias(node.uA, node.forward_value) + if isinstance(lb, Tensor) and isinstance(new_lb, Tensor) and lb.ndim > 0 and lb.ndim != new_lb.ndim: + new_lb = new_lb.reshape(lb.shape) + if isinstance(ub, Tensor) and isinstance(new_ub, Tensor) and ub.ndim > 0 and ub.ndim != new_ub.ndim: + new_ub = new_ub.reshape(ub.shape) + lb = lb + new_lb # FIXME (09/16): shape for the bias of BoundConstant. + ub = ub + new_ub + return lb, ub + + +def save_A_record(node, A_record, A_dict, root, needed_A_dict, lb, ub, unstable_idx): + root_A_record = {} + for i in range(len(root)): + if root[i].lA is None and root[i].uA is None: continue + if root[i].name in needed_A_dict: + if root[i].lA is not None: + if isinstance(root[i].lA, Patches): + _lA = root[i].lA + else: + _lA = root[i].lA.transpose(0, 1).detach() + else: + _lA = None + + if root[i].uA is not None: + if isinstance(root[i].uA, Patches): + _uA = root[i].uA + else: + _uA = root[i].uA.transpose(0, 1).detach() + else: + _uA = None + root_A_record.update({root[i].name: { + "lA": _lA, + "uA": _uA, + # When not used, lb or ub is tensor(0). They have been transposed above.
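+ # (A tensor(0) placeholder has ndim == 0, so the ndim > 1 checks below + # store None in that case.)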
+ "lbias": lb.detach() if lb.ndim > 1 else None, + "ubias": ub.detach() if ub.ndim > 1 else None, + "unstable_idx": unstable_idx + }}) + root_A_record.update(A_record) # merge to existing A_record + A_dict.update({node.name: root_A_record}) + + +def select_unstable_idx(ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size): + """When there are too many unstable neurons, only bound those + with the loosest reference bounds.""" + gap = ( + ref_intermediate_ub[:, unstable_locs] + - ref_intermediate_lb[:, unstable_locs]).sum(dim=0) + indices = torch.argsort(gap, descending=True) + indices_selected = indices[:max_crown_size] + indices_selected, _ = torch.sort(indices_selected) + print(f'{len(indices_selected)}/{len(indices)} unstable neurons selected for CROWN') + return indices_selected + + +def get_unstable_locations( + self, ref_intermediate_lb, ref_intermediate_ub, conv=False, channel_only=False): + max_crown_size = self.bound_opts.get('max_crown_size', int(1e9)) + # For conv layer we only check the case where all neurons are active/inactive. + unstable_masks = torch.logical_and(ref_intermediate_lb < 0, ref_intermediate_ub > 0) + # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable mask. + # It has shape (H, W) indicating if a neuron is unstable/stable. + # TODO: so far we merge over the batch dimension to allow easier implementation. + if channel_only: + # Only keep channels with unstable neurons. Used for initializing alpha. + unstable_locs = unstable_masks.sum(dim=(0,2,3)).bool() + # Shape is consistent with linear layers: a list of unstable neuron channels (no batch dim). + unstable_idx = unstable_locs.nonzero().squeeze(1) + else: + if not conv and unstable_masks.ndim > 2: + # Flatten the conv layer shape. + unstable_masks = unstable_masks.view(unstable_masks.size(0), -1) + ref_intermediate_lb = ref_intermediate_lb.view(ref_intermediate_lb.size(0), -1) + ref_intermediate_ub = ref_intermediate_ub.view(ref_intermediate_ub.size(0), -1) + unstable_locs = unstable_masks.sum(dim=0).bool() + if conv: + # Now converting it to indices for these unstable nuerons. + # These are locations (i,j) of unstable neurons. + unstable_idx = unstable_locs.nonzero(as_tuple=True) + else: + unstable_idx = unstable_locs.nonzero().squeeze(1) + + unstable_size = get_unstable_size(unstable_idx) + if unstable_size > max_crown_size: + indices_seleted = select_unstable_idx( + ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size) + if isinstance(unstable_idx, tuple): + unstable_idx = tuple(u[indices_seleted] for u in unstable_idx) + else: + unstable_idx = unstable_idx[indices_seleted] + unstable_size = get_unstable_size(unstable_idx) + + return unstable_idx, unstable_size + + +def get_alpha_crown_start_nodes( + self, node, c=None, share_slopes=False, final_node_name=None): + # When use_full_conv_alpha is True, conv layers do not share alpha. + sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False) + use_full_conv_alpha_thresh = self.bound_opts.get('use_full_conv_alpha_thresh', 512) + + start_nodes = [] + for nj in self.backward_from[node.name]: # Pre-activation layers. + unstable_idx = None + use_sparse_conv = None + use_full_conv_alpha = self.bound_opts.get('use_full_conv_alpha', False) + if (sparse_intermediate_bounds and isinstance(node, BoundRelu) + and nj.name != final_node_name and not share_slopes): + # Create sparse optimization variables for intermediate neurons. 
+ if ((isinstance(nj, BoundLinear) or isinstance(nj, BoundMatMul)) + and int(os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0): + # unstable_idx has shape [neuron_size_of_nj]. Batch dimension is reduced. + unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper) + elif isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches': + if nj.name in node.patch_size: + # unstable_idx has shape [channel_size_of_nj]. Batch and spatial dimensions are reduced. + unstable_idx, _ = self.get_unstable_locations( + nj.lower, nj.upper, channel_only=not use_full_conv_alpha, conv=True) + use_sparse_conv = False # alpha is shared among channels. Sparse-spec alpha in hw dimension not used. + if use_full_conv_alpha and unstable_idx[0].size(0) > use_full_conv_alpha_thresh: + # Too many unstable neurons. Using shared alpha per channel. + unstable_idx, _ = self.get_unstable_locations( + nj.lower, nj.upper, channel_only=True, conv=True) + use_full_conv_alpha = False + else: + # matrix mode for conv layers. + # unstable_idx has shape [c_out * h_out * w_out]. Batch dimension is reduced. + unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper) + use_sparse_conv = True # alpha is not shared among channels, and is sparse in spec dimension. + if nj.name == final_node_name: + size_final = self[final_node_name].output_shape[-1] if c is None else c.size(1) + start_nodes.append((final_node_name, size_final, None)) + continue + if share_slopes: + # all intermediate neurons from the same layer share the same set of slopes. + output_shape = 1 + elif isinstance(node, BoundOptimizableActivation) and node.patch_size and nj.name in node.patch_size: + # Patches mode. Use output channel size as the spec size. This still shares some alpha, but better than no sharing. + if use_full_conv_alpha: + # alphas not shared among channels, so the spec dim shape is c,h,w + # The patch size is [out_ch, batch, out_h, out_w, in_ch, H, W]. We use out_ch as the output shape. + output_shape = node.patch_size[nj.name][0], node.patch_size[nj.name][2], node.patch_size[nj.name][3] + else: + # The spec dim is c only, and is shared among h, w. + output_shape = node.patch_size[nj.name][0] + assert not sparse_intermediate_bounds or use_sparse_conv is False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha. + else: + # Output is linear layer, or patch converted to matrix. + assert not sparse_intermediate_bounds or use_sparse_conv is not False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha. + output_shape = nj.lower.shape[1:] # FIXME: for non-relu activations it's still expecting a prod. 
+ start_nodes.append((nj.name, output_shape, unstable_idx)) + return start_nodes diff --git a/auto_LiRPA/beta_crown.py b/auto_LiRPA/beta_crown.py new file mode 100644 index 0000000..6be6dc0 --- /dev/null +++ b/auto_LiRPA/beta_crown.py @@ -0,0 +1,48 @@ +import torch + + +def beta_bias(self): + batch_size = len(self.relus[-1].split_beta) + batch = int(batch_size/2) + bias = torch.zeros((batch_size, 1), device=self.device) + for m in self.relus: + if not m.used or not m.perturbed: + continue + if m.split_beta_used: + bias[:batch] = bias[:batch] + m.split_bias*m.split_beta[:batch]*m.split_c[:batch] + bias[batch:] = bias[batch:] + m.split_bias*m.split_beta[batch:]*m.split_c[batch:] + if m.history_beta_used: + bias = bias + (m.new_history_bias*m.new_history_beta*m.new_history_c).sum(1, keepdim=True) + # No single node split here, because single node splits do not have bias. + return bias + + +def print_optimized_beta(self, relus, intermediate_beta_enabled=False): + masked_betas = [] + for model in relus: + masked_betas.append(model.masked_beta) + if model.history_beta_used: + print(f"{model.name} history beta", model.new_history_beta.squeeze()) + if model.split_beta_used: + print(f"{model.name} split beta:", model.split_beta.view(-1)) + print(f"{model.name} bias:", model.split_bias) + + +def save_best_intermediate_betas(self, relus, idx): + for layer in relus: + # The history split and current split are handled separately. + if layer.history_beta_used: + # Each key in history_intermediate_betas for this layer is a dictionary, with all other pre-relu layers' names. + for k, v in layer.history_intermediate_betas.items(): + # This is a tensor with shape (batch, *intermediate_layer_shape, number_of_beta) + self.best_intermediate_betas[layer.name]['history'][k]["lb"][idx] = v["lb"][idx] + self.best_intermediate_betas[layer.name]['history'][k]["ub"][idx] = v["ub"][idx] + if layer.split_beta_used: + for k, v in layer.split_intermediate_betas.items(): + # This is a tensor with shape (batch, *intermediate_layer_shape, 1) + self.best_intermediate_betas[layer.name]['split'][k]["lb"][idx] = v["lb"][idx] + self.best_intermediate_betas[layer.name]['split'][k]["ub"][idx] = v["ub"][idx] + if layer.single_beta_used: + for k, v in layer.single_intermediate_betas.items(): + self.best_intermediate_betas[layer.name]['single'][k]["lb"][idx] = v["lb"][idx] + self.best_intermediate_betas[layer.name]['single'][k]["ub"][idx] = v["ub"][idx] \ No newline at end of file diff --git a/auto_LiRPA/bound_general.py b/auto_LiRPA/bound_general.py index 27ceeae..d02ebe0 100644 --- a/auto_LiRPA/bound_general.py +++ b/auto_LiRPA/bound_general.py @@ -1,11 +1,9 @@ -import time -import os import numpy as np -from collections import OrderedDict, deque, defaultdict +import warnings +from collections import OrderedDict, deque import torch -import torch.optim as optim -from torch.nn import DataParallel, Parameter, parameter +from torch.nn import Parameter from .bound_op_map import bound_op_map from .bound_ops import * @@ -13,12 +11,11 @@ from .parse_graph import parse_module from .perturbations import * from .utils import * -from .adam_element_lr import AdamElementLR +from .patches import Patches -import warnings -warnings.simplefilter("once") +warnings.simplefilter('once') class BoundedModule(nn.Module): """Bounded module with support for automatically computing bounds. @@ -26,93 +23,203 @@ class BoundedModule(nn.Module): Args: model (nn.Module): The original model to be wrapped by BoundedModule.
- global_input (tuple): A dummy input to the original model. The shape of - the dummy input should be consistent with the actual input to the model + global_input (tuple): A dummy input to the original model. The shape of + the dummy input should be consistent with the actual input to the model except for the batch dimension. - bound_opts (dict): Options for bounds. See + bound_opts (dict): Options for bounds. See `Bound Options `_. - device (str or torch.device): Device of the bounded module. - If 'auto', the device will be automatically inferred from the device of + device (str or torch.device): Device of the bounded module. + If 'auto', the device will be automatically inferred from the device of parameters in the original model or the dummy input. - custom_ops (dict): A dictionary of custom operators. - The dictionary maps operator names to their corresponding bound classes + custom_ops (dict): A dictionary of custom operators. + The dictionary maps operator names to their corresponding bound classes (subclasses of `Bound`). """ - def __init__(self, model, global_input, bound_opts=None, auto_batch_dim=True, device='auto', - verbose=False, custom_ops={}): - super(BoundedModule, self).__init__() + def __init__(self, model, global_input, bound_opts=None, + device='auto', verbose=False, custom_ops=None): + super().__init__() if isinstance(model, BoundedModule): for key in model.__dict__.keys(): setattr(self, key, getattr(model, key)) return + + self.global_input = global_input + self.ori_training = model.training + self.check_incompatible_nodes(model) + if bound_opts is None: bound_opts = {} # Default options. - default_bound_opts = {'ibp_relative': False, 'conv_mode': 'patches', 'sparse_intermediate_bounds': True, 'sparse_conv_intermediate_bounds': True} + default_bound_opts = { + 'conv_mode': 'patches', + 'sparse_intermediate_bounds': True, + 'sparse_conv_intermediate_bounds': True, + 'sparse_intermediate_bounds_with_ibp': True, + 'sparse_features_alpha': True, + 'sparse_spec_alpha': True, + 'minimum_sparsity': 0.9, + 'enable_opt_interm_bounds': False, + 'crown_batch_size': np.inf, + 'forward_refinement': False, + 'dynamic_forward': False, + 'forward_max_dim': int(1e9), + # Do not share alpha for conv layers. + 'use_full_conv_alpha': True, + # Threshold for number of unstable neurons for each layer to disable + # use_full_conv_alpha. + 'use_full_conv_alpha_thresh': 512, + 'verbosity': 1 if verbose else 0, + } default_bound_opts.update(bound_opts) self.bound_opts = default_bound_opts self.verbose = verbose - self.custom_ops = custom_ops - self.auto_batch_dim = auto_batch_dim + self.custom_ops = custom_ops if custom_ops is not None else {} if device == 'auto': try: self.device = next(model.parameters()).device - except StopIteration: # Model has no parameters. We use the device of input tensor. + except StopIteration: + # Model has no parameters. We use the device of input tensor. 
                self.device = global_input.device
        else:
            self.device = device
-        self.global_input = global_input
-        self.ibp_relative = self.bound_opts.get('ibp_relative', False)
         self.conv_mode = self.bound_opts.get('conv_mode', 'patches')
-        if auto_batch_dim:
-            # logger.warning('Using automatic batch dimension inferring, which may not be correct')
-            self.init_batch_size = -1
+        # Cached IBP results which may be reused
+        self.ibp_lower, self.ibp_upper = None, None
+
+        self.optimizable_activations = []
+        self.relus = []  # save relu layers for convenience
 
         state_dict_copy = copy.deepcopy(model.state_dict())
         object.__setattr__(self, 'ori_state_dict', state_dict_copy)
         model.to(self.device)
-        self.final_shape = model(*unpack_inputs(global_input, device=self.device)).shape
+        self.final_shape = model(
+            *unpack_inputs(global_input, device=self.device)).shape
         self.bound_opts.update({'final_shape': self.final_shape})
         self._convert(model, global_input)
         self._mark_perturbed_nodes()
 
         # set the default values here
-        optimize_bound_args = {'ob_iteration': 20, 'ob_beta': False, 'ob_alpha': True, 'ob_alpha_share_slopes': False,
-                               'ob_opt_coeffs': False, 'ob_opt_bias': False,
-                               'ob_optimizer': 'adam', 'ob_verbose': 0,
-                               'ob_keep_best': True, 'ob_update_by_layer': True, 'ob_lr': 0.5,
-                               'ob_lr_beta': 0.05, 'ob_init': True,
-                               'ob_single_node_split': True, 'ob_lr_intermediate_beta': 0.1,
-                               'ob_lr_coeffs': 0.01, 'ob_intermediate_beta': False, 'ob_intermediate_refinement_layers': [-1],
-                               'ob_loss_reduction_func': reduction_sum,
-                               'ob_stop_criterion_func': lambda x: False,
-                               'ob_input_grad': False,
-                               'ob_lr_decay': 0.98 }
+        optimize_bound_args = {
+            'enable_alpha_crown': True,  # Enable optimization of alpha.
+            'enable_beta_crown': False,  # Enable beta split constraint.
+            'iteration': 20,  # Number of alpha/beta optimization iterations.
+            # Share some alpha variables to save memory at the cost of slightly
+            # looser bounds.
+            'use_shared_alpha': False,
+            # Optimize coeffs during intermediate_refinement.
+            'opt_coeffs': False,
+            # Optimize constraint bias during intermediate_refinement.
+            'opt_bias': False,
+            # Optimizer used for alpha and beta optimization.
+            'optimizer': 'adam',
+            # Save best results of alpha/beta/bounds during optimization.
+            'keep_best': True,
+            # Only optimize bounds of last layer during alpha/beta CROWN.
+            'fix_intermediate_layer_bounds': True,
+            # Learning rate for the optimizable parameter alpha in alpha-CROWN.
+            'lr_alpha': 0.5,
+            # Learning rate for the optimizable parameter beta in beta-CROWN.
+            'lr_beta': 0.05,
+            'lr_cut_beta': 5e-3,  # Learning rate for optimizing cut betas.
+            # Initialize alpha variables by calling CROWN once.
+            'init_alpha': True,
+            # Only split single nodes in branch and bound.
+            'single_node_split': True,
+            # Learning rate for intermediate layer beta for refinement.
+            'lr_intermediate_beta': 0.1,
+            'lr_coeffs': 0.01,  # Learning rate for coeffs for refinement.
+            # Enable beta on intermediate layer bounds (refinement).
+            'intermediate_beta': False,
+            # Layers to be refined, separated by commas.
+            # -1 means preactivation before last relu.
+            'intermediate_refinement_layers': [-1],
+            # When batch size is not 1, this reduction function is applied to
+            # reduce the bounds into a scalar.
+            'loss_reduction_func': reduction_sum,
+            # Criterion function for early stopping.
+            'stop_criterion_func': lambda x: False,
+            # Learning rate decay factor during bounds optimization.
+            'lr_decay': 0.98,
+            # Number of iterations after which we start to consider early
+            # stopping if there is no improvement.
+            'early_stop_patience': 10,
+            # Start to save optimized best bounds
+            # when current_iteration > int(iteration*start_save_best)
+            'start_save_best': 0.5,
+            # Use double fp (float64) at the last iteration in alpha/beta CROWN.
+            'use_float64_in_last_iteration': False,
+            # Prune verified domains within iteration.
+            'pruning_in_iteration': False,
+            # Percentage of the minimum domains that can apply pruning.
+            'pruning_in_iteration_threshold': 0.2,
+            # For specifications that output multiple bounds for one
+            # property, we use this function to prune them.
+            'multi_spec_keep_func': lambda x: True
+        }
+
         # change by bound_opts
-        optimize_bound_args.update(self.bound_opts.get('optimize_bound_args', {}))
+        optimize_bound_args.update(
+            self.bound_opts.get('optimize_bound_args', {}))
         self.bound_opts.update({'optimize_bound_args': optimize_bound_args})
 
         self.next_split_hint = []  # Split hints, used in beta optimization.
-        self.relus = []  # save relu layers for convenience
-        for l in self._modules.values():
-            if isinstance(l, BoundRelu):
-                self.relus.append(l)
-        self.optimizable_activations = []
-        for l in self._modules.values():
-            if isinstance(l, BoundOptimizableActivation):
-                self.optimizable_activations.append(l)
-
-        # Beta values for all intermediate bounds. Set to None (not used) by default.
+        # Beta values for all intermediate bounds.
+        # Set to None (not used) by default.
         self.best_intermediate_betas = None
         # Initialization value for intermediate betas.
         self.init_intermediate_betas = None
+        # Whether cut constraints are used.
+        self.cut_used = False
+        # A placeholder for the cut timestamp, which is a non-positive int.
+        self.cut_timestamp = -1
+
+        # List of operators. When we are computing intermediate bounds for these
+        # ops, we simply use IBP to propagate bounds from its input nodes,
+        # instead of CROWN.
+        self.ibp_intermediate = [BoundRelu, BoundNeg, BoundTranspose]
+
+        # A placeholder to save the latest samplewise mask for
+        # pruning-in-iteration optimization.
+        self.last_update_preserve_mask = None
+
+    @property
+    def perturbed_optimizable_activations(self):
+        return [n for n in self.optimizable_activations if n.perturbed]
+
+    def check_incompatible_nodes(self, model):
+        """Check whether the model has incompatible nodes that would make the
+        conversion inaccurate."""
+        node_types = [type(m) for m in list(model.modules())]
+
+        if (torch.nn.Dropout in node_types
+                and torch.nn.BatchNorm1d in node_types
+                and self.global_input.shape[0] == 1):
+            print('We cannot support torch.nn.Dropout and torch.nn.BatchNorm1d '
+                  'at the same time!')
+            print('Consider using another dummy input with batch size '
+                  'larger than 1 and setting the model to train() mode.')
+            return
+
+        if not self.ori_training and torch.nn.Dropout in node_types:
+            print('Dropout operation CANNOT be parsed during conversion when '
+                  'the model is in eval() mode!')
+            print('Setting the model to train() mode!')
+            self.ori_training = True
+
+        if self.ori_training and torch.nn.BatchNorm1d in node_types:
+            print('BatchNorm1d may raise errors during conversion when the model'
+                  ' is in train() mode!')
+            print('Setting the model to eval() mode!')
+            self.ori_training = False
 
-    """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it."""
     def non_deter_wrapper(self, op, *args, **kwargs):
+        """Some operations are non-deterministic and deterministic mode will
+        fail. So we temporarily disable it."""
        if self.bound_opts.get('deterministic', False):
            torch.use_deterministic_algorithms(False)
        ret = op(*args, **kwargs)
@@ -128,22 +235,38 @@ def non_deter_index_select(self, *args, **kwargs):
 
     def set_bound_opts(self, new_opts):
         for k, v in new_opts.items():
-            assert v is not dict, 'only support change optimize_bound_args'
-            self.bound_opts[k].update(v)
+            # assert v is not dict, 'only support change optimize_bound_args'
+            if type(v) == dict:
+                self.bound_opts[k].update(v)
+            else:
+                self.bound_opts[k] = v
 
-    def __call__(self, *input, **kwargs):
+    @staticmethod
+    def _get_A_norm(A):
+        if not isinstance(A, (list, tuple)):
+            A = (A, )
+        norms = []
+        for aa in A:
+            if aa is not None:
+                if isinstance(aa, Patches):
+                    aa = aa.patches
+                norms.append(aa.abs().sum().item())
+            else:
+                norms.append(None)
+        return norms
 
-        if "method_opt" in kwargs:
-            opt = kwargs["method_opt"]
-            kwargs.pop("method_opt")
+    def __call__(self, *input, **kwargs):
+        if 'method_opt' in kwargs:
+            opt = kwargs['method_opt']
+            kwargs.pop('method_opt')
         else:
-            opt = "forward"
+            opt = 'forward'
         for kwarg in [
                 'disable_multi_gpu', 'no_replicas', 'get_property',
                 'node_class', 'att_name']:
            if kwarg in kwargs:
                kwargs.pop(kwarg)
-        if opt == "compute_bounds":
+        if opt == 'compute_bounds':
            return self.compute_bounds(**kwargs)
        else:
            return self.forward(*input, **kwargs)
@@ -160,28 +283,29 @@ def register_parameter(self, name, param):
         """
         if '_parameters' not in self.__dict__:
             raise AttributeError(
-                "cannot assign parameter before Module.__init__() call")
+                'cannot assign parameter before Module.__init__() call')
         elif not isinstance(name, torch._six.string_classes):
-            raise TypeError("parameter name should be a string. "
-                            "Got {}".format(torch.typename(name)))
+            raise TypeError('parameter name should be a string. '
+                            f'Got {torch.typename(name)}')
         elif name == '':
-            raise KeyError("parameter name can't be empty string \"\"")
+            raise KeyError('parameter name can\'t be empty string')
         elif hasattr(self, name) and name not in self._parameters:
-            raise KeyError("attribute '{}' already exists".format(name))
 
+            raise KeyError(f'attribute "{name}" already exists')
         if param is None:
             self._parameters[name] = None
         elif not isinstance(param, Parameter):
-            raise TypeError("cannot assign '{}' object to parameter '{}' "
-                            "(torch.nn.Parameter or None required)"
-                            .format(torch.typename(param), name))
+            raise TypeError(
+                f'cannot assign "{torch.typename(param)}" object to '
+                f'parameter "{name}" '
+                '(torch.nn.Parameter or None required)')
         elif param.grad_fn:
             raise ValueError(
-                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
-                "parameters must be created explicitly. To express '{0}' "
-                "as a function of another Tensor, compute the value in "
-                "the forward() method.".format(name))
+                f'Cannot assign non-leaf Tensor to parameter "{name}". Model '
+                f'parameters must be created explicitly. 
To express "{name}" ' + 'as a function of another Tensor, compute the value in ' + 'the forward() method.') else: self._parameters[name] = param @@ -191,12 +315,13 @@ def load_state_dict(self, state_dict, strict=False): for k, v in state_dict.items(): if k in self.node_name_map: new_dict[self.node_name_map[k]] = v - return super(BoundedModule, self).load_state_dict(new_dict, strict=strict) + return super().load_state_dict(new_dict, strict=strict) def _named_members(self, get_members_fn, prefix='', recurse=True): r"""Helper method for yielding various names + members of modules.""" memo = set() - modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] + modules = self.named_modules(prefix=prefix) if recurse else [ + (prefix, self)] for module_prefix, module in modules: members = get_members_fn(module) for k, v in members: @@ -219,75 +344,90 @@ def eval(self): for node in self._modules.values(): node.eval() - def forward(self, *x, final_node_name=None, clear_forward_only=False): + def to(self, *args, **kwargs): + # Moves and/or casts some attributes except pytorch will do by default. + for node in self._modules.values(): + for attr in ['lower', 'upper', 'forward_value', 'd', 'lA',]: + if hasattr(node, attr): + this_attr = getattr(node, attr) + if isinstance(this_attr, torch.Tensor): + # print(node, attr) + this_attr = this_attr.to(*args, **kwargs) + setattr(node, attr, this_attr) + + if hasattr(node, 'interval'): + # construct new interval + this_attr = getattr(node, 'interval') + setattr(node, 'interval', (this_attr[0].to( + *args, **kwargs), this_attr[1].to(*args, **kwargs))) + + return super().to(*args, **kwargs) + + def __getitem__(self, name): + return self._modules[name] + + def final_node(self): + return self[self.final_name] + + def get_forward_value(self, node): + """ Recursively get `forward_value` for `node` and its parent nodes""" + if getattr(node, 'forward_value', None) is not None: + return node.forward_value + inputs = [self.get_forward_value(inp) for inp in node.inputs] + for inp in node.inputs: + node.from_input = node.from_input or inp.from_input + node.input_shape = inputs[0].shape if len(inputs) > 0 else None + fv = node.forward(*inputs) + if isinstance(fv, (torch.Size, tuple)): + fv = torch.tensor(fv, device=self.device) + node.forward_value = fv + node.output_shape = fv.shape + # In most cases, the batch dimension is just the first dimension + # if the node depends on input. Otherwise if the node doesn't + # depend on input, there is no batch dimension. + node.batch_dim = 0 if node.from_input else -1 + # Unperturbed node but it is not a root node. + # Save forward_value to value. (Can be used in forward bounds.) + if not node.from_input and len(node.inputs) > 0: + node.value = node.forward_value + return fv + + def forward(self, *x, final_node_name=None, clear_forward_only=False, + reset_perturbed_nodes=True): r"""Standard forward computation for the network. Args: x (tuple or None): Input to the model. - final_node_name (str, optional): The name of the final node in the model. The value - on the corresponding node will be returned. + final_node_name (str, optional): The name of the final node in the + model. The value on the corresponding node will be returned. + + clear_forward_only (bool, default `False`): Whether only standard + forward values stored on the nodes should be cleared. If `True`, + only standard forward values stored on the nodes will be cleared. + Otherwise, bound information on the nodes will also be cleared. 
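A hedged usage sketch of the rewritten lazy `forward` documented here (not part of the diff; the two-layer net is only illustrative, and `final_name` is the output-node name set during graph construction):

```python
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dummy = torch.randn(1, 4)
model = BoundedModule(net, dummy)

out = model(dummy)                 # standard forward pass through the wrapped graph
name = model.final_name            # name of the final (output) node
same = model.forward(dummy, final_node_name=name)
assert torch.allclose(out, same)   # requesting the final node returns the same value
```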
- clear_forward_only (bool, default `False`): Whether only standard forward values stored - on the nodes should be cleared. If `True`, only standard forward values stored on the - nodes will be cleared. Otherwise, bound information on the nodes will also be cleared. + reset_perturbed_nodes (bool, default `True`): Mark all perturbed + nodes with input perturbations. When set to `True`, it may + accidentally clear all .perturbed properties for intermediate + nodes. Returns: - output: The output of the model, or if `final_node_name` is not `None`, return the - value on the corresponding node instead. + output: The output of the model, or if `final_node_name` is not + `None`, return the value on the corresponding node instead. """ - - self._set_input(*x, clear_forward_only=clear_forward_only) - - degree_in = {} - queue = deque() - for key in self._modules.keys(): - l = self._modules[key] - degree_in[l.name] = len(l.input_name) - if degree_in[l.name] == 0: - queue.append(l) - forward_values = {} - - final_output = None - while len(queue) > 0: - l = queue.popleft() - inp = [forward_values[l_pre] for l_pre in l.input_name] - for l_pre in l.inputs: - l.from_input = l.from_input or l_pre.from_input - fv = l.forward(*inp) - if isinstance(fv, torch.Size) or isinstance(fv, tuple): - fv = torch.tensor(fv, device=self.device) - object.__setattr__(l, 'forward_value', fv) - # infer batch dimension - if not hasattr(l, 'batch_dim'): - inp_batch_dim = [l_pre.batch_dim for l_pre in l.inputs] - try: - l.batch_dim = l.infer_batch_dim(self.init_batch_size, *inp_batch_dim) - except: - raise Exception( - 'Fail to infer the batch dimension of ({})[{}]: forward_value shape {}, input batch dimensions {}'.format( - l, l.name, l.forward_value.shape, inp_batch_dim)) - forward_values[l.name] = l.forward_value - - # Unperturbed node but it is not a root node. Save forward_value to value. - # (Can be used in forward bounds.) - if not l.from_input and len(l.inputs) > 0: - l.value = l.forward_value - - for l_next in l.output_name: - degree_in[l_next] -= 1 - if degree_in[l_next] == 0: # all inputs of this node have already set - queue.append(self._modules[l_next]) - + self._set_input(*x, clear_forward_only=clear_forward_only, + reset_perturbed_nodes=reset_perturbed_nodes) if final_node_name: - return forward_values[final_node_name] + return self.get_forward_value(self[final_node_name]) else: - out = deque([forward_values[n] for n in self.output_name]) + out = deque([self.get_forward_value(self[n]) + for n in self.output_name]) def _fill_template(template): if template is None: return out.popleft() - elif isinstance(template, list) or isinstance(template, tuple): + elif isinstance(template, (list, tuple)): res = [] for t in template: res.append(_fill_template(t)) @@ -302,84 +442,111 @@ def _fill_template(template): return _fill_template(self.output_template) - """Mark the graph nodes and determine which nodes need perturbation.""" def _mark_perturbed_nodes(self): + """Mark the graph nodes and determine which nodes need perturbation.""" degree_in = {} queue = deque() # Initially the queue contains all "root" nodes. for key in self._modules.keys(): - l = self._modules[key] - degree_in[l.name] = len(l.input_name) + l = self[key] + degree_in[l.name] = len(l.inputs) if degree_in[l.name] == 0: queue.append(l) # in_degree ==0 -> root node while len(queue) > 0: l = queue.popleft() - # Obtain all output node, and add the output nodes to the queue if all its input nodes have been visited. 
- # the initial "perturbed" property is set in BoundInput or BoundParams object, depending on ptb. + # Obtain all output node, and add the output nodes to the queue if + # all its input nodes have been visited. + # The initial "perturbed" property is set in BoundInput or + # BoundParams object, depending on ptb. for name_next in l.output_name: - node_next = self._modules[name_next] + node_next = self[name_next] if isinstance(l, BoundShape): - # Some nodes like Shape, even connected, do not really propagate bounds. + # Some nodes like Shape, even connected, + # do not really propagate bounds. # TODO: make this a property of node? pass else: - # The next node is perturbed if it is already perturbed, or this node is perturbed. + # The next node is perturbed if it is already perturbed, + # or this node is perturbed. node_next.perturbed = node_next.perturbed or l.perturbed degree_in[name_next] -= 1 - if degree_in[name_next] == 0: # all inputs of this node have been visited, now put it in queue. + # all inputs of this node have been visited, + # now put it in queue. + if degree_in[name_next] == 0: queue.append(node_next) return - def _clear_and_set_new(self, new_interval, clear_forward_only=False): + def _clear_and_set_new( + self, intermediate_layer_bounds, clear_forward_only=False, + reset_perturbed_nodes=True): for l in self._modules.values(): if hasattr(l, 'linear'): if isinstance(l.linear, tuple): for item in l.linear: - del (item) + del item delattr(l, 'linear') + if hasattr(l, 'patch_size'): + l.patch_size = {} + if clear_forward_only: if hasattr(l, 'forward_value'): - delattr(l, 'forward_value') + delattr(l, 'forward_value') else: - for attr in ['lower', 'upper', 'interval', 'forward_value', 'd', 'lA', 'lower_d']: + for attr in [ + 'lower', 'upper', 'interval', 'forward_value', 'd', + 'lA', 'lower_d']: if hasattr(l, attr): delattr(l, attr) - for attr in ['zero_backward_coeffs_l', 'zero_backward_coeffs_u', 'zero_lA_mtx', 'zero_uA_mtx']: + for attr in [ + 'zero_backward_coeffs_l', 'zero_backward_coeffs_u', + 'zero_lA_mtx', 'zero_uA_mtx']: setattr(l, attr, False) # Given an interval here to make IBP/CROWN start from this node - if new_interval is not None and l.name in new_interval.keys(): - l.interval = tuple(new_interval[l.name][:2]) - l.lower = new_interval[l.name][0] - l.upper = new_interval[l.name][1] + if (intermediate_layer_bounds is not None + and l.name in intermediate_layer_bounds.keys()): + l.interval = tuple(intermediate_layer_bounds[l.name][:2]) + l.lower = intermediate_layer_bounds[l.name][0] + l.upper = intermediate_layer_bounds[l.name][1] + if l.lower is not None: + l.lower = l.lower.detach().requires_grad_(False) + if l.upper is not None: + l.upper = l.upper.detach().requires_grad_(False) # Mark all nodes as non-perturbed except for weights. 
- if not hasattr(l, 'perturbation') or l.perturbation is None: - l.perturbed = False - - def _set_input(self, *x, new_interval=None, clear_forward_only=False): - self._clear_and_set_new(new_interval=new_interval, clear_forward_only=clear_forward_only) + if reset_perturbed_nodes: + if not hasattr(l, 'perturbation') or l.perturbation is None: + l.perturbed = False + + # Clear operator-specific attributes + l.clear() + + def _set_input( + self, *x, intermediate_layer_bounds=None, + clear_forward_only=False, reset_perturbed_nodes=True): + self._clear_and_set_new( + intermediate_layer_bounds=intermediate_layer_bounds, + clear_forward_only=clear_forward_only, + reset_perturbed_nodes=reset_perturbed_nodes) inputs_unpacked = unpack_inputs(x) for name, index in zip(self.input_name, self.input_index): - node = self._modules[name] + if index is None: + continue + node = self[name] node.value = inputs_unpacked[index] if isinstance(node.value, (BoundedTensor, BoundedParameter)): node.perturbation = node.value.ptb else: node.perturbation = None # Mark all perturbed nodes. - self._mark_perturbed_nodes() - if self.init_batch_size == -1: - # Automatic batch dimension inferring: get the batch size from - # the first dimension of the first input tensor. - self.init_batch_size = inputs_unpacked[0].shape[0] + if reset_perturbed_nodes: + self._mark_perturbed_nodes() def _get_node_input(self, nodesOP, nodesIn, node): ret = [] ori_names = [] for i in range(len(node.inputs)): - found = False for op in nodesOP: if op.name == node.inputs[i]: ret.append(op.bound_node) @@ -392,84 +559,108 @@ def _get_node_input(self, nodesOP, nodesIn, node): ori_names.append(io.ori_name) break if len(ret) <= i: - raise ValueError('cannot find inputs of node: {}'.format(node.name)) - return ret, ori_names + raise ValueError(f'cannot find inputs of node: {node.name}') + return ret - # move all tensors in the object to a specified device - def _to(self, obj, device): - if isinstance(obj, torch.Tensor): - return obj.to(device) + def _to(self, obj, dest, inplace=False): + """ Move all tensors in the object to a specified dest + (device or dtype). 
The inplace=True option is available for dict.""" + if obj is None: + return obj + elif isinstance(obj, torch.Tensor): + return obj.to(dest) + elif isinstance(obj, Patches): + return obj.patches.to(dest) elif isinstance(obj, tuple): - return tuple([self._to(item, device) for item in obj]) + return tuple([self._to(item, dest) for item in obj]) elif isinstance(obj, list): - return list([self._to(item, device) for item in obj]) + return list([self._to(item, dest) for item in obj]) elif isinstance(obj, dict): - res = {} - for key in obj: - res[key] = self._to(obj[key], device) - return res + if inplace: + for k, v in obj.items(): + obj[k] = self._to(v, dest, inplace=True) + return obj + else: + return {k: self._to(v, dest) for k, v in obj.items()} else: raise NotImplementedError(type(obj)) def _convert_nodes(self, model, global_input): + r""" + Returns: + nodesOP (list): List of operator nodes + nodesIn (list): List of input nodes + nodesOut (list): List of output nodes + template (object): Template to specify the output format + """ global_input_cpu = self._to(global_input, 'cpu') - model.train() + if self.ori_training: + model.train() + else: + model.eval() model.to('cpu') - nodesOP, nodesIn, nodesOut, template = parse_module(model, global_input_cpu) + nodesOP, nodesIn, nodesOut, template = parse_module( + model, global_input_cpu) model.to(self.device) for i in range(0, len(nodesIn)): if nodesIn[i].param is not None: - nodesIn[i] = nodesIn[i]._replace(param=nodesIn[i].param.to(self.device)) + nodesIn[i] = nodesIn[i]._replace( + param=nodesIn[i].param.to(self.device)) global_input_unpacked = unpack_inputs(global_input) # Convert input nodes and parameters. for i, n in enumerate(nodesIn): if n.input_index is not None: nodesIn[i] = nodesIn[i]._replace(bound_node=BoundInput( - nodesIn[i].inputs, nodesIn[i].name, nodesIn[i].ori_name, + ori_name=nodesIn[i].ori_name, value=global_input_unpacked[nodesIn[i].input_index], - perturbation=nodesIn[i].perturbation)) + perturbation=nodesIn[i].perturbation, + input_index=n.input_index)) else: - bound_class = BoundParams if isinstance(nodesIn[i].param, nn.Parameter) else BoundBuffers + bound_class = BoundParams if isinstance( + nodesIn[i].param, nn.Parameter) else BoundBuffers nodesIn[i] = nodesIn[i]._replace(bound_node=bound_class( - nodesIn[i].inputs, nodesIn[i].name, nodesIn[i].ori_name, - value=nodesIn[i].param, perturbation=nodesIn[i].perturbation)) + ori_name=nodesIn[i].ori_name, + value=nodesIn[i].param, + perturbation=nodesIn[i].perturbation)) unsupported_ops = [] # Convert other operation nodes. 
for n in range(len(nodesOP)): attr = nodesOP[n].attr - inputs, ori_names = self._get_node_input(nodesOP, nodesIn, nodesOP[n]) - + inputs = self._get_node_input(nodesOP, nodesIn, nodesOP[n]) try: if nodesOP[n].op in self.custom_ops: op = self.custom_ops[nodesOP[n].op] elif nodesOP[n].op in bound_op_map: op = bound_op_map[nodesOP[n].op] elif nodesOP[n].op.startswith('aten::ATen'): - op = eval('BoundATen{}'.format(attr['operator'].capitalize())) + op = globals()[f'BoundATen{attr["operator"].capitalize()}'] elif nodesOP[n].op.startswith('onnx::'): - op = eval('Bound{}'.format(nodesOP[n].op[6:])) + op = globals()[f'Bound{nodesOP[n].op[6:]}'] else: raise KeyError except (NameError, KeyError): unsupported_ops.append(nodesOP[n]) - logger.error('The node has an unsupported operation: {}'.format(nodesOP[n])) + logger.error('The node has an unsupported operation: %s', + nodesOP[n]) continue - if nodesOP[n].op == 'onnx::BatchNormalization': - # BatchNormalization node needs model.training flag to set running mean and vars - # set training=False to avoid wrongly updating running mean/vars during bound wrapper - nodesOP[n] = nodesOP[n]._replace( - bound_node=op( - nodesOP[n].inputs, nodesOP[n].name, None, attr, - inputs, nodesOP[n].output_index, self.bound_opts, self.device, False)) + attr['device'] = self.device + + # FIXME generalize + if (nodesOP[n].op == 'onnx::BatchNormalization' + or getattr(op, 'TRAINING_FLAG', False)): + # BatchNormalization node needs model.training flag to set + # running mean and vars set training=False to avoid wrongly + # updating running mean/vars during bound wrapper + nodesOP[n] = nodesOP[n]._replace(bound_node=op( + attr, inputs, nodesOP[n].output_index, self.bound_opts, + False)) else: - nodesOP[n] = nodesOP[n]._replace( - bound_node=op( - nodesOP[n].inputs, nodesOP[n].name, None, attr, - inputs, nodesOP[n].output_index, self.bound_opts, self.device)) + nodesOP[n] = nodesOP[n]._replace(bound_node=op( + attr, inputs, nodesOP[n].output_index, self.bound_opts)) if unsupported_ops: logger.error('Unsupported operations:') @@ -477,51 +668,83 @@ def _convert_nodes(self, model, global_input): logger.error(f'Name: {n.op}, Attr: {n.attr}') raise NotImplementedError('There are unsupported operations') + for node in nodesIn + nodesOP: + node.bound_node.input_name = node.inputs + node.bound_node.name = node.name + + nodes_dict = {} + for node in nodesOP + nodesIn: + nodes_dict[node.name] = node.bound_node + nodesOP = [n.bound_node for n in nodesOP] + nodesIn = [n.bound_node for n in nodesIn] + nodesOut = [nodes_dict[n] for n in nodesOut] + return nodesOP, nodesIn, nodesOut, template def _build_graph(self, nodesOP, nodesIn, nodesOut, template): - nodes = [] - for node in nodesOP + nodesIn: - assert (node.bound_node is not None) - nodes.append(node.bound_node) # We were assuming that the original model had only one output node. - # When there are multiple output nodes, this seems to be the first output element. - # In this case, we are assuming that we aim to compute the bounds for the first - # output element by default. - self.final_name = nodesOut[0] + # When there are multiple output nodes, this seems to be the first + # output element. In this case, we are assuming that we aim to compute + # the bounds for the first output element by default. 
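As a usage note on the lookup order implemented above (`custom_ops` first, then `bound_op_map`, then a `Bound{OpName}` class found via `globals()`), registering a custom operator might look like the following hedged sketch; `MyBoundOp` and `'onnx::MyOp'` are hypothetical, and `net`/`dummy` are placeholders for a model and dummy input:

```python
from auto_LiRPA import BoundedModule
from auto_LiRPA.bound_ops import Bound

class MyBoundOp(Bound):  # hypothetical Bound subclass for a custom op
    pass

# Nodes whose traced op type is 'onnx::MyOp' will be bound by MyBoundOp.
model = BoundedModule(net, dummy, custom_ops={'onnx::MyOp': MyBoundOp})
```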
+ self.final_name = nodesOut[0].name self.input_name, self.input_index, self.root_name = [], [], [] - for node in nodesIn: - self.root_name.append(node.name) - if node.input_index is not None: - self.input_name.append(node.name) - self.input_index.append(node.input_index) - self.output_name = nodesOut + self.output_name = [n.name for n in nodesOut] self.output_template = template - for l in nodes: - self._modules[l.name] = l - l.output_name = [] - if isinstance(l.input_name, str): - l.input_name = [l.input_name] - for l in nodes: - for l_pre in l.input_name: - self._modules[l_pre].output_name.append(l.name) - for l in nodes: - if self.conv_mode != 'patches' and len(l.input_name) == 0: - if not l.name in self.root_name: - # Add independent nodes that do not appear in `nodesIn`. - # Note that these nodes are added in the last, since - # order matters in the current implementation because - # `root[0]` is used in some places. - self.root_name.append(l.name) + for node in nodesIn: + self.add_input_node(node, index=node.input_index) + self.add_nodes(nodesOP) + if self.conv_mode == 'patches': + self.root_name = [node.name for node in nodesIn] + + # Make sure the nodes already have `name` and `input_name` + def add_nodes(self, nodes): + nodes = [(node if isinstance(node, Bound) else node.bound_node) + for node in nodes] + for node in nodes: + self._modules[node.name] = node + node.output_name = [] + if not hasattr(node, 'input_name'): + node.input_name = [] + if isinstance(node.input_name, str): + node.input_name = [node.input_name] + if len(node.inputs) == 0: + self.root_name.append(node.name) + for node in nodes: + for l_pre in node.inputs: + l_pre.output_name.append(node.name) + for node in nodes: + if isinstance(node, BoundOptimizableActivation): + self.optimizable_activations.append(node) + if isinstance(node, BoundRelu): + self.relus.append(node) + + def add_input_node(self, node, index=None): + self.add_nodes([node]) + self.input_name.append(node.name) + # default value for input_index + if index == 'auto': + index = max([0] + [(i + 1) + for i in self.input_index if i is not None]) + self.input_index.append(index) + + def rename_nodes(self, nodesOP, nodesIn, rename_dict): + def rename(node): + node.name = rename_dict[node.name] + node.input_name = [ + rename_dict[name] for name in node.input_name] + return node + for i in range(len(nodesOP)): + nodesOP[i] = rename(nodesOP[i]) + for i in range(len(nodesIn)): + nodesIn[i] = rename(nodesIn[i]) def _split_complex(self, nodesOP, nodesIn): finished = True for n in range(len(nodesOP)): - if hasattr(nodesOP[n].bound_node, 'complex') and \ - nodesOP[n].bound_node.complex: + if hasattr(nodesOP[n], 'complex') and nodesOP[n].complex: finished = False - _nodesOP, _nodesIn, _nodesOut, _template = self._convert_nodes( - nodesOP[n].bound_node.model, nodesOP[n].bound_node.input) + _nodesOP, _nodesIn, _nodesOut, _ = self._convert_nodes( + nodesOP[n].model, nodesOP[n].input) # assuming each supported complex operation only has one output assert len(_nodesOut) == 1 @@ -529,37 +752,28 @@ def _split_complex(self, nodesOP, nodesIn): rename_dict = {} for node in _nodesOP + _nodesIn: rename_dict[node.name] = name_base + node.name - num_inputs = len(nodesOP[n].bound_node.input) + num_inputs = len(nodesOP[n].inputs) for i in range(num_inputs): - rename_dict[_nodesIn[i].name] = nodesOP[n].inputs[i] + rename_dict[_nodesIn[i].name] = nodesOP[n].input_name[i] rename_dict[_nodesOP[-1].name] = nodesOP[n].name - def rename(node): - node.bound_node.name = 
rename_dict[node.bound_node.name] - node.bound_node.input_name = [ - rename_dict[name] for name in node.bound_node.input_name] - node = node._replace( - name=rename_dict[node.name], - inputs=node.bound_node.input_name) - return node - - for i in range(len(_nodesOP)): - _nodesOP[i] = rename(_nodesOP[i]) - for i in range(len(_nodesIn)): - _nodesIn[i] = rename(_nodesIn[i]) + self.rename_nodes(_nodesOP, _nodesIn, rename_dict) + output_name = _nodesOP[-1].name - # Any input node of some node within the complex node should be - # replaced with the corresponding input node of the complex node. + # Any input node of some node within the complex node should be + # replaced with the complex node's corresponding input node. for node in _nodesOP: - for i in range(len(node.bound_node.inputs)): - if node.bound_node.input_name[i] in nodesOP[n].inputs: - index = nodesOP[n].inputs.index(node.bound_node.input_name[i]) - node.bound_node.inputs[i] = nodesOP[n].bound_node.inputs[index] - # For any output node of this complex node, modify its input node + for i in range(len(node.inputs)): + if node.input_name[i] in nodesOP[n].input_name: + index = nodesOP[n].input_name.index( + node.input_name[i]) + node.inputs[i] = nodesOP[n].inputs[index] + # For any output node of this complex node, + # modify its input node. for node in nodesOP: - if output_name in node.bound_node.input_name: - index = node.bound_node.input_name.index(output_name) - node.bound_node.inputs[index] = _nodesOP[-1].bound_node + if output_name in node.input_name: + index = node.input_name.index(output_name) + node.inputs[index] = _nodesOP[-1] nodesOP = nodesOP[:n] + _nodesOP + nodesOP[(n + 1):] nodesIn = nodesIn + _nodesIn[num_inputs:] @@ -568,19 +782,21 @@ def rename(node): return nodesOP, nodesIn, finished - """build a dict with {ori_name: name, name: ori_name}""" - def _get_node_name_map(self, ): + def _get_node_name_map(self): + """Build a dict with {ori_name: name, name: ori_name}""" self.node_name_map = {} for node in self._modules.values(): if isinstance(node, BoundInput) or isinstance(node, BoundParams): for p in list(node.named_parameters()): if node.ori_name not in self.node_name_map: - self.node_name_map[node.ori_name] = node.name + '.' + p[0] - self.node_name_map[node.name + '.' + p[0]] = node.ori_name + name = f'{node.name}.{p[0]}' + self.node_name_map[node.ori_name] = name + self.node_name_map[name] = node.ori_name for p in list(node.named_buffers()): if node.ori_name not in self.node_name_map: - self.node_name_map[node.ori_name] = node.name + '.' + p[0] - self.node_name_map[node.name + '.' + p[0]] = node.ori_name + name = f'{node.name}.{p[0]}' + self.node_name_map[node.ori_name] = name + self.node_name_map[name] = node.ori_name # convert a Pytorch model to a model with bounds def _convert(self, model, global_input): @@ -591,7 +807,8 @@ def _convert(self, model, global_input): global_input = (global_input,) self.num_global_inputs = len(global_input) - nodesOP, nodesIn, nodesOut, template = self._convert_nodes(model, global_input) + nodesOP, nodesIn, nodesOut, template = self._convert_nodes( + model, global_input) global_input = self._to(global_input, self.device) while True: @@ -603,468 +820,314 @@ def _convert(self, model, global_input): self._get_node_name_map() - # load self.ori_state_dict again to avoid the running means/vars changed during forward() + # Load self.ori_state_dict again to avoid the running means/vars changed + # during forward(). 
self.load_state_dict(self.ori_state_dict) - model.load_state_dict(self.ori_state_dict) + if self.ori_training: + model.load_state_dict(self.ori_state_dict) delattr(self, 'ori_state_dict') # The final node used in the last time calling `compute_bounds` self.last_final_node = None - - logger.debug('NodesOP:') - for node in nodesOP: - logger.debug('{}'.format(node._replace(param=None))) - logger.debug('NodesIn') - for node in nodesIn: - logger.debug('{}'.format(node._replace(param=None))) + self.used_nodes = [] if self.verbose: logger.info('Model converted to support bounds') - def init_slope(self, x, share_slopes=False, method='backward', c=None, bound_lower=True, bound_upper=True): - if method != 'forward': - assert isinstance(x, tuple) - assert method == 'backward' - x = x[0] + def check_prior_bounds(self, node): + if node.prior_checked or not (node.used and node.perturbed): + return - for node in self.optimizable_activations: - # initialize the parameters - node.opt_init() + for n in node.inputs: + self.check_prior_bounds(n) - with torch.no_grad(): - if method == 'forward': - l, u = self.compute_bounds(x=(x,), method='forward', bound_lower=bound_lower, bound_upper=bound_upper) - else: - l, u = self.compute_bounds(x=(x,), IBP=False, C=c, method='backward', return_A=False, bound_lower=bound_lower, bound_upper=bound_upper) + if getattr(node, 'nonlinear', False): + for n in node.inputs: + self.compute_intermediate_bounds(n, prior_checked=True) + for i in getattr(node, 'requires_input_bounds', []): + self.compute_intermediate_bounds( + node.inputs[i], prior_checked=True) - init_intermediate_bounds = {} - for node in self.optimizable_activations: - if method == 'forward': - assert not '_forward' in self._modules.keys(), '_forward is a reserved node name' - assert isinstance(node, BoundRelu), 'Only ReLU is supported for optimizing forward bounds' - start_nodes = [ ('_forward', 1) ] - else: - start_nodes = [] - for nj in self.backward_from[node.name]: - if nj.name == self.final_name: - size_final = self.final_shape[-1] if c is None else c.size(1) - start_nodes.append((self.final_name, size_final)) - continue - if share_slopes: - # all intermediate neurons from the same layer share the same set of slopes. - output_shape = 1 - elif isinstance(node, BoundRelu) and node.patch_size and nj.name in node.patch_size: - # Patches mode. Use output channel size as the spec size. This still shares some alpha, but better than no sharing. - # The patch size is [out_ch, batch, out_h, out_w, in_ch, H, W]. We use out_ch as the output shape. 
- output_shape = node.patch_size[nj.name][0] - # print(f'node {nj.name} {nj} use patch size {output_shape} patches size {node.patch_size[nj.name][0]}') - else: - output_shape = prod(nj.lower.shape[1:]) - # print(f'node {nj.name} {nj} use regular size {output_shape}') - start_nodes.append((nj.name, output_shape)) - node.init_opt_parameters(start_nodes) - node.opt_start() - init_intermediate_bounds[node.inputs[0].name] = ([node.inputs[0].lower.detach(), node.inputs[0].upper.detach()]) - - print("alpha-CROWN optimizable variables initialized.") - return l, u, init_intermediate_bounds - - def beta_bias(self): - batch_size = len(self.relus[-1].split_beta) - batch = int(batch_size/2) - bias = torch.zeros((batch_size, 1), device=self.device) - for m in self.relus: - if m.split_beta_used: - bias[:batch] = bias[:batch] + m.split_bias*m.split_beta[:batch]*m.split_c[:batch] - bias[batch:] = bias[batch:] + m.split_bias*m.split_beta[batch:]*m.split_c[batch:] - if m.history_beta_used: - bias = bias + (m.new_history_bias*m.new_history_beta*m.new_history_c).sum(1, keepdim=True) - # No single node split here, because single node splits do not have bias. - return bias - - - def get_optimized_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, method='backward', - bound_lower=True, bound_upper=False, reuse_ibp=False, return_A=False, final_node_name=None, - average_A=False, new_interval=None, reference_bounds=None, aux_reference_bounds=None, needed_A_dict=None): - # optimize CROWN lower bound by alpha and beta - opts = self.bound_opts['optimize_bound_args'] - iteration = opts['ob_iteration']; beta = opts['ob_beta']; alpha = opts['ob_alpha'] - opt_coeffs = opts['ob_opt_coeffs']; opt_bias = opts['ob_opt_bias'] - verbose = opts['ob_verbose']; opt_choice = opts['ob_optimizer'] - single_node_split = opts['ob_single_node_split'] - keep_best = opts['ob_keep_best']; update_by_layer = opts['ob_update_by_layer']; init = opts['ob_init'] - lr = opts['ob_lr']; lr_beta = opts['ob_lr_beta'] - lr_intermediate_beta = opts['ob_lr_intermediate_beta'] - intermediate_beta_enabled = opts['ob_intermediate_beta'] - lr_decay = opts['ob_lr_decay']; lr_coeffs = opts['ob_lr_coeffs'] - loss_reduction_func = opts['ob_loss_reduction_func'] - stop_criterion_func = opts['ob_stop_criterion_func'] - input_grad = opts['ob_input_grad'] - sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False) - # verbose = 1 - - assert bound_lower != bound_upper, 'we can only optimize lower OR upper bound at one time' - assert alpha or beta, "nothing to optimize, use compute bound instead!" - - if C is not None: - self.final_shape = C.size()[:2] - self.bound_opts.update({'final_shape': self.final_shape}) - if init: - self.init_slope(x, share_slopes=opts['ob_alpha_share_slopes'], method=method, c=C) - - alphas = [] - betas = [] - parameters = [] - dense_coeffs_mask = [] - - for m in self.optimizable_activations: - if alpha: - alphas.extend(list(m.alpha.values())) - - if alpha: - # Alpha has shape (2, output_shape, batch_dim, node_shape) - parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2}) - # best_alpha is a dictionary of dictionary. Each key is the alpha variable for one relu layer, and each value is a dictionary contains all relu layers after that layer as keys. 
- best_alphas = OrderedDict() - for m in self.optimizable_activations: - best_alphas[m.name] = {} - for alpha_m in m.alpha: - best_alphas[m.name][alpha_m] = m.alpha[alpha_m].clone().detach() - # We will directly replace the dictionary for each relu layer after optimization, so the saved alpha might not have require_grad=True. - m.alpha[alpha_m].requires_grad_() - - if beta: - if len(self.relus) != len(self.optimizable_activations): - raise NotImplementedError("Beta-CROWN for tanh models is not supported yet") - - if single_node_split: - for model in self.relus: - betas.append(model.sparse_beta) - else: - betas = self.beta_params + self.single_beta_params - if opt_coeffs: - coeffs = [dense_coeffs["dense"] for dense_coeffs in self.split_dense_coeffs_params] + self.coeffs_params - dense_coeffs_mask = [dense_coeffs["mask"] for dense_coeffs in self.split_dense_coeffs_params] - parameters.append({'params': coeffs, 'lr': lr_coeffs}) - best_coeffs = [coeff.clone().detach() for coeff in coeffs] - if opt_bias: - biases = self.bias_params - parameters.append({'params': biases, 'lr': lr_coeffs}) - best_biases = [bias.clone().detach() for bias in biases] - - # Beta has shape (batch, max_splits_per_layer) - parameters.append({'params': betas, 'lr': lr_beta, 'batch_dim': 0}) - best_betas = [b.clone().detach() for b in betas] - - start = time.time() - - - if opt_choice == "adam-autolr": - opt = AdamElementLR(parameters, lr=lr) - elif opt_choice == "adam": - opt = optim.Adam(parameters, lr=lr) - elif opt_choice == 'sgd': - opt = optim.SGD(parameters, lr=lr, momentum=0.9) - else: - raise NotImplementedError(opt_choice) + node.prior_checked = True + + def compute_intermediate_bounds(self, node, prior_checked=False): + if getattr(node, 'lower', None) is not None: + return - # Create a weight vector to scale learning rate. - loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device) + logger.debug(f'Getting the bounds of {node}') - scheduler = optim.lr_scheduler.ExponentialLR(opt, lr_decay) + if not prior_checked: + self.check_prior_bounds(node) - last_l = math.inf - last_total_loss = torch.tensor(1e8, device=x[0].device, dtype=x[0].dtype) - best_l = torch.zeros([x[0].shape[0], 1], device=x[0].device, dtype=x[0].dtype) + 1e8 + if not node.perturbed: + fv = self.get_forward_value(node) + node.interval = node.lower, node.upper = fv, fv + return + # FIXME check that weight perturbation is not affected + # (from_input=True should be set for weights) + if not node.from_input and hasattr(node, 'forward_value'): + node.lower = node.upper = self.get_forward_value(node) + return - best_intermediate_bounds = [] # TODO: this should be a dictionary to handle more general architectures. + reference_bounds = self.reference_bounds - if sparse_intermediate_bounds and aux_reference_bounds is None and reference_bounds is not None: - aux_reference_bounds = {} - for name, (lb, ub) in reference_bounds.items(): - aux_reference_bounds[name] = (lb.detach().clone(), ub.detach().clone()) - if aux_reference_bounds is None: - aux_reference_bounds = {} + if self.use_forward: + node.lower, node.upper = self.forward_general( + node=node, concretize=True) + return - for i in range(iteration): - intermediate_constr = None - - if not update_by_layer: - reference_bounds = new_interval # If we still optimize all intermediate neurons, we can use new_interval as reference bounds. 
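The removed loop here is what `method='CROWN-Optimized'` runs internally; after this refactor the old hard-coded `ob_*` names map to the plain keys in `optimize_bound_args`. A hedged sketch of driving it through the public API (`model` is assumed to be a `BoundedModule` and `bounded_x` a `BoundedTensor`):

```python
model.set_bound_opts({'optimize_bound_args': {
    'iteration': 20,   # number of alpha optimization steps
    'lr_alpha': 0.1,   # learning rate for the alpha variables
    'lr_decay': 0.98,  # ExponentialLR decay factor, as in the loop above
}})
lb, ub = model.compute_bounds(x=(bounded_x,), method='CROWN-Optimized')
```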
- - ret = self.compute_bounds(x, aux, C, method=method, IBP=IBP, forward=forward, - bound_lower=bound_lower, bound_upper=bound_upper, reuse_ibp=reuse_ibp, - return_A=return_A, final_node_name=final_node_name, average_A=average_A, - # If we set neuron bounds individually, or if we are optimizing intermediate layer bounds using beta, we do not set new_interval. - # When intermediate betas are used, we must set new_interval to None because we want to recompute all intermediate layer bounds. - new_interval=partial_new_interval if beta and intermediate_beta_enabled else new_interval if update_by_layer else None, - # This is the currently tightest interval, which will be used to pass split constraints when intermediate betas are used. - reference_bounds=reference_bounds, - # This is the interval used for checking for unstable neurons. - aux_reference_bounds=aux_reference_bounds if sparse_intermediate_bounds else None, - # These are intermediate layer beta variables and their corresponding A matrices and biases. - intermediate_constr=intermediate_constr, needed_A_dict=needed_A_dict) - - if i == 0: - best_ret = ret - for model in self.optimizable_activations: - best_intermediate_bounds.append([model.inputs[0].lower.clone().detach(), model.inputs[0].upper.clone().detach()]) - if sparse_intermediate_bounds: - aux_reference_bounds[model.inputs[0].name] = best_intermediate_bounds[-1] - - ret_l, ret_u = ret[0], ret[1] - - if beta and opt_bias and not single_node_split: - ret_l = ret_l + self.beta_bias() - ret = (ret_l, ret_u) # where is A matrix? - - l = ret_l - if ret_l is not None and ret_l.shape[1] != 1: # Reduction over the spec dimension. - l = loss_reduction_func(ret_l) - u = ret_u - if ret_u is not None and ret_u.shape[1] != 1: - u = loss_reduction_func(ret_u) - - if True: - loss_ = l if bound_lower else -u - stop_criterion = stop_criterion_func(ret_l) if bound_lower else stop_criterion_func(-ret_u) - total_loss = -1 * loss_ - if type(stop_criterion) == bool: - loss = total_loss.sum() * (not stop_criterion) + #FIXME need clean up + + # assign concretized bound for ReLU layer to save computational cost + # FIXME: Put ReLU after reshape will cause problem! + if self.check_IBP_intermediate(node): + # Intermediate bounds for some operators are directly + # computed from their input nodes by IBP + # (such as BoundRelu, BoundNeg) + logger.debug('IBP propagation for intermediate bounds on %s', node) + elif (isinstance(node, BoundReshape) + and hasattr(node.inputs[0], 'lower') + and hasattr(node.inputs[1], 'value')): + # TODO merge this with `check_IBP_intermediate` + # Node for input value. + val_input = node.inputs[0] + # Node for input parameter (e.g., shape, permute) + arg_input = node.inputs[1] + node.lower = node.forward(val_input.lower, arg_input.value) + node.upper = node.forward(val_input.upper, arg_input.value) + node.interval = (node.lower, node.upper) + else: + # For the first linear layer, IBP can give the same tightness + # as CROWN. + if self.check_IBP_first_linear(node): + return + + sparse_intermediate_bounds_with_ibp = self.bound_opts.get( + 'sparse_intermediate_bounds_with_ibp', True) + # Sparse intermediate bounds can be enabled + # if aux_reference_bounds are given. + # (this is enabled for ReLU only, and not for other + # activations.) 
+ sparse_intermediate_bounds = (self.bound_opts.get( + 'sparse_intermediate_bounds', False) + and isinstance(self[node.output_name[0]], BoundRelu)) + + ref_intermediate_lb, ref_intermediate_ub = None, None + if sparse_intermediate_bounds: + if node.name not in self.aux_reference_bounds: + # If aux_reference_bounds are not available, + # we can use IBP to compute these bounds. + if sparse_intermediate_bounds_with_ibp: + with torch.no_grad(): + # Get IBP bounds for this layer; + # we set delete_bounds_after_use=True which does + # not save extra intermediate bound tensors. + ret_ibp = self.IBP_general( + node=node, delete_bounds_after_use=True) + ref_intermediate_lb = ret_ibp[0] + ref_intermediate_ub = ret_ibp[1] + else: + sparse_intermediate_bounds = False else: - loss = (total_loss * stop_criterion.logical_not()).sum() - - with torch.no_grad(): - # Save varibles if this is the best iteration. - if keep_best and (total_loss < best_l).any(): - # we only pick up the results improved in a batch - idx = (total_loss < best_l).squeeze() - best_l[idx] = total_loss[idx] - - if ret[0] is not None: - best_ret[0][idx] = ret[0][idx] - if ret[1] is not None: - best_ret[1][idx] = ret[1][idx] - if return_A: - best_ret = (best_ret[0], best_ret[1], ret[2]) - - for ii, model in enumerate(self.optimizable_activations): - # best_intermediate_bounds.append([model.inputs[0].lower, model.inputs[0].upper]) - best_intermediate_bounds[ii][0][idx] = torch.max(best_intermediate_bounds[ii][0][idx], model.inputs[0].lower[idx]) - best_intermediate_bounds[ii][1][idx] = torch.min(best_intermediate_bounds[ii][1][idx], model.inputs[0].upper[idx]) - if alpha: - # each alpha has shape (2, output_shape, batch, *shape) - for alpha_m in model.alpha: - best_alphas[model.name][alpha_m][:,:,idx] = model.alpha[alpha_m][:,:,idx].clone().detach() - if beta and single_node_split: - best_betas[ii][idx] = betas[ii][idx].clone().detach() - - if not single_node_split and beta: - for ii, b in enumerate(betas): - best_betas[ii][idx] = b[idx].clone().detach() - - if opt_coeffs: - best_coeffs = [co.clone().detach() for co in coeffs] # TODO: idx-wise - if opt_bias: - best_biases = [bias.clone().detach() for bias in biases] # TODO: idx-wise - - - if os.environ.get('AUTOLIRPA_DEBUG_OPT', False): - print(f"****** iter [{i}]", - f"loss: {loss.item()}, lr: {opt.param_groups[0]['lr']}") - - if isinstance(stop_criterion, torch.Tensor) and stop_criterion.all(): - print(f"\nall verified at {i}th iter") - break + aux_bounds = self.aux_reference_bounds[node.name] + ref_intermediate_lb, ref_intermediate_ub = aux_bounds + + sparse_C = self.get_sparse_C( + node, sparse_intermediate_bounds, + ref_intermediate_lb, ref_intermediate_ub) + newC, reduced_dim, unstable_idx, unstable_size = sparse_C + + if unstable_idx is None or unstable_size > 0: + if self.return_A: + node.lower, node.upper, _ = self.backward_general( + C=newC, node=node, unstable_idx=unstable_idx, + unstable_size=unstable_size) + else: + # Compute backward bounds only when there are unstable + # neurons, or when we don't know which neurons are unstable. + node.lower, node.upper = self.backward_general( + C=newC, node=node, unstable_idx=unstable_idx, + unstable_size=unstable_size) + + if reduced_dim: + self.restore_sparse_bounds( + node, unstable_idx, unstable_size, + ref_intermediate_lb, ref_intermediate_ub) + + # node.lower and node.upper (intermediate bounds) are computed in + # the above function. If we have bound references, we set them here + # to always obtain a better set of bounds. 
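The reference-bound merge just below relies on a straight-through-style detach trick: the forward value becomes the tightened bound, while gradients flow through the computed bound unchanged. A self-contained illustration (values made up):

```python
import torch

lower = torch.tensor([-1.0, 2.0], requires_grad=True)
ref_lb = torch.tensor([0.0, 0.0])
# Forward value equals torch.max(ref_lb, lower); gradient is identity w.r.t. lower.
merged = torch.max(ref_lb, lower).detach() - lower.detach() + lower
print(merged)            # tensor([0., 2.], grad_fn=...) -- tightened values
merged.sum().backward()
print(lower.grad)        # tensor([1., 1.]) -- gradients are not masked out
```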
+ if node.name in reference_bounds: + ref_bounds = reference_bounds[node.name] + # Initially, the reference bound and the computed bound can be + # exactly the same when intermediate layer beta is 0. This will + # prevent gradients flow. So we need a small guard here. + if self.intermediate_constr is not None: + # Intermediate layer beta is used. + # Note that we cannot just take the reference bounds if + # they are better - this makes alphas have zero gradients. + new_lower = 0.9 * ref_bounds[0] + 0.1 * node.lower + new_upper = 0.9 * ref_bounds[1] + 0.1 * node.upper + node.lower = torch.max(new_lower, node.lower) + node.upper = torch.min(new_upper, node.upper) + # Additionally, if the reference bounds say a neuron is + # stable, we always keep it. (FIXME: this is for ReLU only). + lower_stable = ref_bounds[0] >= 0. + node.lower[lower_stable] = ref_bounds[0][lower_stable] + upper_stable = ref_bounds[1] <= 0. + node.upper[upper_stable] = ref_bounds[1][upper_stable] + else: + # Set the intermediate layer bounds using reference bounds, + # always choosing the tighter one. + node.lower = ( + torch.max(ref_bounds[0], node.lower).detach() + - node.lower.detach() + node.lower) + node.upper = ( + node.upper - (node.upper.detach() + - torch.min(ref_bounds[1], node.upper).detach())) + # Otherwise, we only use reference bounds to check which neurons + # are unstable. + + # FIXME (12/28): we should be consistent, and only use + # node.interval, do not use node.lower or node.upper! + node.interval = (node.lower, node.upper) - current_lr = [] - for param_group in opt.param_groups: - current_lr.append(param_group['lr']) - - opt.zero_grad(set_to_none=True) - - if input_grad and x[0].ptb.x_L.grad is not None: - x[0].ptb.x_L.grad = None - x[0].ptb.x_U.grad = None - - loss.backward() - - if verbose > 0: - print(f"*** iter [{i}]\n", f"loss: {loss.item()}", total_loss.squeeze().detach().cpu().numpy(), "lr: ", current_lr) - if beta: - masked_betas = [] - for model in self.relus: - masked_betas.append(model.masked_beta) - if model.history_beta_used: - print(f"{model.name} history beta", model.new_history_beta.squeeze()) - if model.split_beta_used: - print(f"{model.name} split beta:", model.split_beta.view(-1)) - print(f"{model.name} bias:", model.split_bias) - if opt_coeffs: - for co in coeffs: - print(f'coeff sum: {co.abs().sum():.5g}') - if beta and i == 0 and verbose > 0: - breakpoint() - - if opt_choice == "adam-autolr": - opt.step(lr_scale=[loss_weight, loss_weight]) - else: - opt.step() - - if beta: - # Clipping to >=0. - for b in betas: - b.data = (b >= 0) * b.data - for dmi in range(len(dense_coeffs_mask)): - # apply dense mask to the dense split coeffs matrix - coeffs[dmi].data = dense_coeffs_mask[dmi].float() * coeffs[dmi].data - - if alpha: - for m in self.relus: - for m_start_node, v in m.alpha.items(): - v.data = torch.clamp(v.data, 0., 1.) - # print(f'layer {m.name} start_node {m_start_node} shape {v.size()} norm {v[:,:,0].abs().sum()} {v[:,:,-1].abs().sum()} {v.abs().sum()}') - # For tanh, we clip it in bound_ops because clipping depends. TODO: clipping should be a method in the BoundOptimizableActivation class. - # on pre-activation bounds - - # If loss has become worse for some element, reset those to current best. 
- with torch.no_grad(): - if beta and opt_choice == "adam-autolr" and i > iteration * 0.2: - for ii, model in enumerate(self.relus): - if alpha: - # each alpha has shape (2, output_shape, batch, *shape) - for alpha_m in model.alpha: - model.alpha[alpha_m][:,:,worse_idx] = best_alphas[model.name][alpha_m][:,:,worse_idx].clone().detach() - if beta and single_node_split: - betas[ii][worse_idx] = best_betas[ii][worse_idx].clone().detach() - - scheduler.step() - last_l = loss.item() - last_total_loss = total_loss.detach().clone() - - # if beta and intermediate_beta_enabled and verbose > 0: - if verbose > 0: - breakpoint() - - if keep_best: - # Set all variables to their saved best values. - with torch.no_grad(): - for idx, model in enumerate(self.optimizable_activations): - if alpha: - # Assigns a new dictionary. - model.alpha = best_alphas[model.name] - model.inputs[0].lower.data = best_intermediate_bounds[idx][0].data - model.inputs[0].upper.data = best_intermediate_bounds[idx][1].data - if beta: - if single_node_split: - model.sparse_beta.copy_(best_betas[idx]) - else: - for b, bb in zip(betas, best_betas): - b.data = bb.data - if opt_coeffs: - for co, bco in zip(coeffs, best_coeffs): - co.data = bco.data - if opt_bias: - for bias, bbias in zip(biases, best_biases): - bias.data = bbias.data - - if new_interval is not None and not update_by_layer: - for l in self._modules.values(): - if l.name in new_interval.keys() and hasattr(l, "lower"): - # l.interval = tuple(new_interval[l.name][:2]) - l.lower = torch.max(l.lower, new_interval[l.name][0]) - l.upper = torch.min(l.upper, new_interval[l.name][1]) - infeasible_neurons = l.lower > l.upper - if infeasible_neurons.any(): - print('infeasible!!!!!!!!!!!!!!', infeasible_neurons.sum().item(), infeasible_neurons.nonzero()[:, 0]) - - print("best_l after optimization:", best_l.sum().item(), "with beta sum per layer:", [p.sum().item() for p in betas]) - # np.save('solve_slope.npy', np.array(record)) - print('optimal alpha/beta time:', time.time() - start) - return best_ret - - - def get_unstable_conv_locations(self, node, aux_reference_bounds): - # For conv layer we only check the case where all neurons are active/inactive. - unstable_masks = torch.logical_and(aux_reference_bounds[node.name][0] < 0, aux_reference_bounds[node.name][1] > 0) - # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable mask. - # It has shape (H, W) indicating if a neuron is unstable/stable. - # TODO: so far we merge over the batch dimension to allow easier implementation. - unstable_locs = unstable_masks.sum(dim=0).bool() - # Now converting it to indices for these unstable nuerons. These are locations (i,j) of unstable neurons. - unstable_idx = unstable_locs.nonzero(as_tuple=True) - # Number of unstable neurons. - unstable_size = unstable_idx[0].numel() - # print(f'layer {node.name} unstable_size {unstable_size} actual {unstable_masks.sum().item()} size {node.output_shape}') - # We sum over the channel direction, so need to multiply that. - return unstable_idx, unstable_size - - - def get_unstable_locations(self, node, aux_reference_bounds): - # FIXME (09/19): this is for ReLU only! - unstable_masks = torch.logical_and(aux_reference_bounds[node.name][0] < 0, aux_reference_bounds[node.name][1] > 0) - # unstable_masks = torch.ones(dtype=torch.bool, size=(batch_size, dim), device=self.device) - if unstable_masks.ndim > 2: - # Flatten the conv layer shape. 
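Both the helpers removed here and their replacements build on the same test: a ReLU neuron is unstable iff its pre-activation bounds straddle zero. A self-contained sketch of the batch-merged index computation (toy numbers):

```python
import torch

lb = torch.tensor([[-1.0, 0.5, -0.2], [-0.3, 1.0, -2.0]])
ub = torch.tensor([[ 2.0, 1.5,  0.1], [-0.1, 2.0,  3.0]])
unstable = torch.logical_and(lb < 0, ub > 0)  # per-example unstable mask
# Merge over the batch, so one shared index set is used for all examples.
shared = unstable.sum(dim=0).bool()
unstable_idx = shared.nonzero().squeeze(1)
print(unstable_idx)  # tensor([0, 2])
```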
- unstable_masks = unstable_masks.view(unstable_masks.size(0), -1) - # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable mask. - unstable_locs = unstable_masks.sum(dim=0).bool() - # This is a 1-d indices, shared by all elements in this batch. - unstable_idx = unstable_locs.nonzero().squeeze(1) - unstable_size = unstable_idx.numel() - # print(f'layer {node.name} unstable {unstable_size} total {node.output_shape}') - return unstable_idx, unstable_size - - def compute_bounds(self, x=None, aux=None, C=None, method='backward', IBP=False, forward=False, - bound_lower=True, bound_upper=True, reuse_ibp=False, - return_A=False, needed_A_dict=None, final_node_name=None, average_A=False, new_interval=None, - return_b=False, b_dict=None, reference_bounds=None, intermediate_constr=None, alpha_idx=None, - aux_reference_bounds=None, need_A_only=False): + def merge_A_dict(self, lA_dict, uA_dict): + merged_A = {} + for output_node_name in lA_dict: + merged_A[output_node_name] = {} + lA_dict_ = lA_dict[output_node_name] + uA_dict_ = uA_dict[output_node_name] + for input_node_name in lA_dict_: + merged_A[output_node_name][input_node_name] = { + 'lA': lA_dict_[input_node_name]['lA'], + 'uA': uA_dict_[input_node_name]['uA'], + 'lbias': lA_dict_[input_node_name]['lbias'], + 'ubias': uA_dict_[input_node_name]['ubias'], + } + return merged_A + + def compute_bounds( + self, x=None, aux=None, C=None, method='backward', IBP=False, + forward=False, bound_lower=True, bound_upper=True, reuse_ibp=False, + reuse_alpha=False, return_A=False, needed_A_dict=None, + final_node_name=None, average_A=False, + intermediate_layer_bounds=None, reference_bounds=None, + intermediate_constr=None, alpha_idx=None, + aux_reference_bounds=None, need_A_only=False, + cutter=None, decision_thresh=None, + update_mask=None): r"""Main function for computing bounds. Args: - x (tuple or None): Input to the model. If it is None, the input from the last - `forward` or `compute_bounds` call is reused. Otherwise: the number of elements in the tuple should be - equal to the number of input nodes in the model, and each element in the tuple - corresponds to the value for each input node respectively. It should look similar - as the `global_input` argument when used for creating a `BoundedModule`. - - aux (object, optional): Auxliary information that can be passed to `Perturbation` - classes for initializing and concretizing bounds, e.g., additional information - for supporting synonym word subsitution perturbaiton. - - C (Tensor): The specification matrix that can map the output of the model with an - additional linear layer. This is usually used for maping the logits output of the - model to classification margins. - - method (str): The main method for bound computation. Choices: + x (tuple or None): Input to the model. If it is None, the input + from the last `forward` or `compute_bounds` call is reused. + Otherwise: the number of elements in the tuple should be + equal to the number of input nodes in the model, and each element in + the tuple corresponds to the value for each input node respectively. + It should look similar to the `global_input` argument when used for + creating a `BoundedModule`. + + aux (object, optional): Auxiliary information that can be passed to + `Perturbation` classes for initializing and concretizing bounds, + e.g., additional information for supporting synonym word substitution + perturbation.
+ + C (Tensor): The specification matrix that can map the output of the + model with an additional linear layer. This is usually used for + mapping the logits output of the model to classification margins. + + method (str): The main method for bound computation. Choices: * `IBP`: purely use Interval Bound Propagation (IBP) bounds. - * `CROWN-IBP`: use IBP to compute intermediate bounds, but use CROWN (backward mode LiRPA) to compute the bounds of the final node. - * `CROWN`: purely use CROWN to compute bounds for intermediate nodes and the final node. - * `Forward`: purely use forward mode LiRPA to compute the bounds. - * `Forward+Backward`: use forward mode LiRPA to compute bounds for intermediate nodes, but further use CROWN to compute bounds for the final node. - * `CROWN-Optimized` or `alpha-CROWN`: use CROWN, and also optimize the linear relaxation parameters for activations. - - IBP (bool, optional): If `True`, use IBP to compute the bounds of intermediate nodes. - It can be automatically set according to `method`. - - forward (bool, optional): If `True`, use the forward mode bound propagation to compute the bounds - of intermediate nodes. It can be automatically set according to `method`. - - bound_lower (bool, default `True`): If `True`, the lower bounds of the output needs to be computed. - - bound_upper (bool, default `True`): If `True`, the upper bounds of the output needs to be computed. - - reuse_ibp (bool, optional): If `True` and `method` is None, reuse the previously saved IBP bounds. + * `CROWN-IBP`: use IBP to compute intermediate bounds, + but use CROWN (backward mode LiRPA) to compute the bounds of the + final node. + * `CROWN`: purely use CROWN to compute bounds for intermediate + nodes and the final node. + * `Forward`: purely use forward mode LiRPA. + * `Forward+Backward`: use forward mode LiRPA for intermediate + nodes, but further use CROWN for the final node. + * `CROWN-Optimized` or `alpha-CROWN`: use CROWN, and also + optimize the linear relaxation parameters for activations. + * `forward-optimized`: use forward bounds with optimized linear + relaxation. + + IBP (bool, optional): If `True`, use IBP to compute the bounds of + intermediate nodes. It can be automatically set according to + `method`. + + forward (bool, optional): If `True`, use the forward mode bound + propagation to compute the bounds of intermediate nodes. It can be + automatically set according to `method`. + + bound_lower (bool, default `True`): If `True`, the lower bounds of + the output need to be computed. + + bound_upper (bool, default `True`): If `True`, the upper bounds of + the output need to be computed. + + reuse_ibp (bool, optional): If `True` and `method` is None, reuse + the previously saved IBP bounds. + + final_node_name (str, optional): Set the final node in the + computational graph for bound computation. By default, the final + node of the originally built computational graph is used. + + return_A (bool, optional): If `True`, return linear coefficients + in bound propagation (`A` tensors) with `needed_A_dict` set. + + needed_A_dict (dict, optional): A dictionary specifying linear + coefficients (`A` tensors) that are needed and should be returned. + Each key in the dictionary is the name of a starting node in + backward bound propagation, with a list as the value for the key, + which specifies the names of the ending nodes in backward bound + propagation, and the linear coefficients of the starting node w.r.t. + the specified ending nodes are returned. By default, it is empty.
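To illustrate `return_A`/`needed_A_dict`, here is a hypothetical sketch that requests the linear coefficients of the final node w.r.t. the first root (input) node; `model` and `my_input` are assumed to be a `BoundedModule` and a `BoundedTensor` as in the quick start, and the returned dictionary layout matches `merge_A_dict` above:

```python
# Hypothetical sketch: request A matrices for one (start, end) pair.
final_name = model.final_name
input_name = model.root_name[0]
needed_A_dict = {final_name: [input_name]}
lb, ub, A_dict = model.compute_bounds(
    x=(my_input,), method='backward',
    return_A=True, needed_A_dict=needed_A_dict)
entry = A_dict[final_name][input_name]
lA, lbias = entry['lA'], entry['lbias']  # lower bound: lA @ x + lbias
uA, ubias = entry['uA'], entry['ubias']  # upper bound: uA @ x + ubias
```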
+ + reuse_alpha (bool, optional): If `True`, reuse previously saved + alpha values when they are not being optimized. + + decision_thresh (float, optional): In CROWN-optimized mode, we will + use this decision_thresh to dynamically optimize only those domains + whose bounds are still <= the threshold. + + intermediate_layer_bounds: A dictionary of 2-element tuple/list + containing lower and upper bounds for intermediate layers. + The dictionary keys should include the names of the layers whose + bounds should be set without recomputation. The layer names can be + viewed by setting environment variable AUTOLIRPA_DEBUG_GRAPH=1. + The value of each dictionary element is (lower_bounds, + upper_bounds) where "lower_bounds" and "upper_bounds" are two + tensors with the same shape as the output shape of this layer. If + you only need to set intermediate layer bounds for certain layers, + then just include these layers' names in the dictionary. + + reference_bounds: Format is similar to "intermediate_layer_bounds". + However, these bounds are only used as a reference, and the bounds + for intermediate layers will still be computed (e.g., using CROWN, + IBP or other specified methods). The computed bounds will be + compared to "reference_bounds" and the tighter one between the two + will be used. + + aux_reference_bounds: Format is similar to intermediate layer + bounds. However, these bounds are only used to determine which + neurons are stable and which neurons are unstable in ReLU networks. + Unstable neurons' intermediate layer bounds will be recomputed. Returns: - bound (tuple): a tuple of computed lower bound and upper bound respectively. + bound (tuple): When `return_A` is `False`, return a tuple of + the computed lower bound and upper bound. When `return_A` + is `True`, return a tuple of lower bound, upper bound, and + `A` dictionary. """ + logger.debug(f'Compute bounds with {method}') + if needed_A_dict is None: needed_A_dict = {} if not bound_lower and not bound_upper: - raise ValueError('At least one of bound_lower and bound_upper must be True') + raise ValueError( + 'At least one of bound_lower and bound_upper must be True') # Several shortcuts. method = method.lower() if method is not None else method @@ -1079,951 +1142,233 @@ def compute_bounds(self, x=None, aux=None, C=None, method='backward', IBP=False, method = 'backward' elif method == 'forward': forward = True - elif method == 'forward+backward': + elif method == 'forward+backward' or method == 'forward+crown': method = 'backward' forward = True - elif method in ['crown-optimized', 'alpha-crown']: - assert return_A is False - ret = [] + elif method in ['crown-optimized', 'alpha-crown', 'forward-optimized']: + # Lower and upper bounds need two separate rounds of optimization.
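As a usage sketch of the three bound dictionaries documented above (layer names are hypothetical and depend on the traced graph):

```python
# Hypothetical sketch: fix the bounds of one intermediate layer and
# provide reference bounds for another. The names '/12' and '/14' are
# made up; set AUTOLIRPA_DEBUG_GRAPH=1 to print the real ones.
fixed = {'/12': (lb_12, ub_12)}        # used as-is, not recomputed
ref = {'/14': (ref_lb_14, ref_ub_14)}  # compared against computed bounds
lb, ub = model.compute_bounds(
    x=(my_input,), method='CROWN',
    intermediate_layer_bounds=fixed,
    reference_bounds=ref)
```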
+ if method == 'forward-optimized': + method = 'forward' + else: + method = 'backward' if bound_lower: - ret1 = self.get_optimized_bounds(x=x, IBP=False, C=C, method='backward', new_interval=new_interval, reference_bounds=reference_bounds, - bound_lower=bound_lower, bound_upper=False, return_A=return_A, aux_reference_bounds=aux_reference_bounds, - needed_A_dict=needed_A_dict) + ret1 = self.get_optimized_bounds( + x=x, C=C, method=method, + intermediate_layer_bounds=intermediate_layer_bounds, + reference_bounds=reference_bounds, bound_lower=bound_lower, + bound_upper=False, return_A=return_A, + aux_reference_bounds=aux_reference_bounds, + needed_A_dict=needed_A_dict, + final_node_name=final_node_name, + cutter=cutter, decision_thresh=decision_thresh) if bound_upper: - ret2 = self.get_optimized_bounds(x=x, IBP=False, C=C, method='backward', new_interval=new_interval, reference_bounds=reference_bounds, - bound_lower=False, bound_upper=bound_upper, return_A=return_A, aux_reference_bounds=aux_reference_bounds, - needed_A_dict=needed_A_dict) + ret2 = self.get_optimized_bounds( + x=x, C=C, method=method, + intermediate_layer_bounds=intermediate_layer_bounds, + reference_bounds=reference_bounds, bound_lower=False, + bound_upper=bound_upper, return_A=return_A, + aux_reference_bounds=aux_reference_bounds, + needed_A_dict=needed_A_dict, + final_node_name=final_node_name, + cutter=cutter, decision_thresh=decision_thresh) if bound_lower and bound_upper: - return ret1[0], ret2[1] + if return_A: + # Needs to merge the A dictionary. + lA_dict = ret1[2] + uA_dict = ret2[2] + merged_A = self.merge_A_dict(lA_dict, uA_dict) + return ret1[0], ret2[1], merged_A + else: + return ret1[0], ret2[1] elif bound_lower: - return ret1 + return ret1 # ret1[1] is None. elif bound_upper: - return ret2 + return ret2 # ret2[0] is None. if reference_bounds is None: reference_bounds = {} if aux_reference_bounds is None: aux_reference_bounds = {} - # If y in self.backward_node_pairs[x], then node y is visited when + # If y in self.backward_node_pairs[x], then node y is visited when # doing backward bound propagation starting from node x. self.backward_from = dict([(node, []) for node in self._modules]) if not bound_lower and not bound_upper: - raise ValueError('At least one of bound_lower and bound_upper in compute_bounds should be True') + raise ValueError( + 'At least one of bound_lower and bound_upper in compute_bounds ' + 'should be True') A_dict = {} if return_A else None if x is not None: - self._set_input(*x, new_interval=new_interval) + self._set_input( + *x, intermediate_layer_bounds=intermediate_layer_bounds) if IBP and method is None and reuse_ibp: # directly return the previously saved ibp bounds return self.ibp_lower, self.ibp_upper - root = [self._modules[name] for name in self.root_name] + root = [self[name] for name in self.root_name] batch_size = root[0].value.shape[0] dim_in = 0 for i in range(len(root)): value = root[i].forward() - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - root[i].linear, root[i].center, root[i].aux = \ - root[i].perturbation.init(value, aux=aux, forward=forward) - # This input/parameter has perturbation. Create an interval object. 
- if self.ibp_relative: - root[i].interval = Interval( - None, None, root[i].linear.nominal, root[i].linear.lower_offset, root[i].linear.upper_offset) - else: - root[i].interval = Interval( - root[i].linear.lower, root[i].linear.upper, ptb=root[i].perturbation) + if getattr(root[i], 'perturbation', None) is not None: + ret_init = root[i].perturbation.init( + value, aux=aux, forward=forward) + root[i].linear, root[i].center, root[i].aux = ret_init + # This input/parameter has perturbation. + # Create an interval object. + root[i].interval = Interval( + root[i].linear.lower, root[i].linear.upper, + ptb=root[i].perturbation) if forward: root[i].dim = root[i].linear.lw.shape[1] dim_in += root[i].dim else: - if self.ibp_relative: - root[i].interval = Interval( - None, None, value, torch.zeros_like(value), torch.zeros_like(value)) - else: - # This inpute/parameter does not has perturbation. - # Use plain tuple defaulting to Linf perturbation. - root[i].interval = (value, value) - root[i].forward_value = root[i].forward_value = root[i].value = root[i].lower = root[i].upper = value + # This input/parameter does not have perturbation. + # Use plain tuple defaulting to Linf perturbation. + root[i].interval = (value, value) + root[i].forward_value = root[i].value = value + root[i].lower = root[i].upper = value - if self.ibp_relative: - root[i].lower, root[i].upper = root[i].interval.lower, root[i].interval.upper - else: - root[i].lower, root[i].upper = root[i].interval + root[i].lower, root[i].upper = root[i].interval if forward: - self._init_forward(root, dim_in) + self.init_forward(root, dim_in) - final = self._modules[self.final_name] if final_node_name is None else self._modules[final_node_name] - logger.debug('Final node {}[{}]'.format(final, final.name)) + final = self.final_node( + ) if final_node_name is None else self[final_node_name] + logger.debug(f'Final node {final.__class__.__name__}({final.name})') if IBP: - res = self._IBP_general(node=final, C=C) - if self.ibp_relative: - self.ibp_lower, self.ibp_upper = res.lower, res.upper - else: - self.ibp_lower, self.ibp_upper = res + self.ibp_lower, self.ibp_upper = self.IBP_general(node=final, C=C) if method is None: - return self.ibp_lower, self.ibp_upper + return self.ibp_lower, self.ibp_upper if C is None: - # C is an identity matrix by default + # C is an identity matrix by default if final.output_shape is None: - raise ValueError('C is not provided while node {} has no default shape'.format(final.shape)) + raise ValueError( + f'C is not provided while node {final} has no default shape') dim_output = int(prod(final.output_shape[1:])) # TODO: use an eyeC object here. - C = torch.eye(dim_output, device=self.device).expand(batch_size, dim_output, dim_output) + C = torch.eye(dim_output, device=self.device).expand( + batch_size, dim_output, dim_output) + + # Reuse previously saved alpha values, + # even if they are not optimized now. + if reuse_alpha: + for node in self.optimizable_activations: + node.opt_reuse() + else: + for node in self.optimizable_activations: + node.opt_no_reuse() + + # Inject update mask inside the activations + # update_mask: None or bool tensor([batch_size]) + # If set to a tensor, only update the alpha and beta of selected + # elements (with element=1).
+ + if update_mask is None: + for node in self.optimizable_activations: + node.clean_alpha_beta_update_mask() + else: + for node in self.optimizable_activations: + node.set_alpha_beta_update_mask(update_mask) - # check whether weights are perturbed and set nonlinear for the BoundMatMul operation for n in self._modules.values(): - if type(n) in [BoundLinear, BoundConv, BoundBatchNormalization]: + # Check whether all prior intermediate bounds already exist + n.prior_checked = False + # check whether weights are perturbed and set nonlinear for the + # BoundMatMul operation + if isinstance(n, (BoundLinear, BoundConv, BoundBatchNormalization)): n.nonlinear = False - for l_name in n.input_name[1:]: - node = self._modules[l_name] + for node in n.inputs[1:]: if hasattr(node, 'perturbation'): if node.perturbation is not None: n.nonlinear = True + if isinstance(n, BoundRelu): + for node in n.inputs: + if isinstance(node, BoundConv): + # whether this Conv is followed by a ReLU + node.relu_followed = True # BFS to find out whether each node is used given the current final node - if final != self.last_final_node: - self.last_final_node = final + self._set_used_nodes(final) + + # FIXME clean + self.use_forward = forward + self.root = root + self.batch_size = batch_size + self.dim_in = dim_in + self.return_A = return_A + self.A_dict = A_dict + self.needed_A_dict = needed_A_dict + self.intermediate_constr = intermediate_constr + self.reference_bounds = reference_bounds + self.aux_reference_bounds = aux_reference_bounds + self.final_node_name = final.name + + self.check_prior_bounds(final) + + if method == 'backward': + # This is for the final output bound. + # No need to pass in intermediate layer beta constraints. + ret = self.backward_general( + C=C, node=final, + bound_lower=bound_lower, bound_upper=bound_upper, + average_A=average_A, need_A_only=need_A_only, + unstable_idx=alpha_idx, update_mask=update_mask) + # FIXME when C is specified, lower and upper should not be saved to + # final.lower and final.upper, because they are not the bounds for + # the node.
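The `update_mask` injected above selects which batch elements have their alpha/beta updated; a hypothetical sketch (assuming `model` and `my_input` as in the quick start, with a batch of four subproblems):

```python
import torch

# Hypothetical sketch: only elements 0 and 2 have their alpha/beta
# updated during this bound computation; 1 and 3 are left untouched.
update_mask = torch.tensor([True, False, True, False])
lb, ub = model.compute_bounds(
    x=(my_input,), method='backward', update_mask=update_mask)
```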
+ final.lower, final.upper = ret[0], ret[1] + return ret + elif method == 'forward': + return self.forward_general(C=C, node=final, concretize=True) + else: + raise NotImplementedError + + def _set_used_nodes(self, final): + if final.name != self.last_final_node: + self.last_final_node = final.name + self.used_nodes = [] for i in self._modules.values(): i.used = False final.used = True queue = deque([final]) while len(queue) > 0: n = queue.popleft() - for n_pre_name in n.input_name: - n_pre = self._modules[n_pre_name] + self.used_nodes.append(n) + for n_pre in n.inputs: if not n_pre.used: n_pre.used = True queue.append(n_pre) - for i in self._modules.values(): - if isinstance(i, BoundRelu): - for l_name in i.input_name: - node = self._modules[l_name] - if isinstance(node, BoundConv): - node.relu_followed = True # whether this Conv is followed by a ReLU - - for i in self._modules.values(): # for all nodes - if not i.used: - continue - if hasattr(i, 'nonlinear') and i.nonlinear: - for l_name in i.input_name: - node = self._modules[l_name] - if not hasattr(node, 'lower'): - assert not IBP, 'There should be no missing intermediate bounds when IBP is enabled' - if not node.perturbed and hasattr(node, 'forward_value'): - node.interval = node.lower, node.upper = \ - node.forward_value, node.forward_value - continue - # FIXME check that weight perturbation is not affected - # (from_input=True should be set for weights) - if not node.from_input and hasattr(node, 'forward_value'): - node.lower = node.upper = node.forward_value - continue - if forward: - l, u = self._forward_general( - node=node, root=root, dim_in=dim_in, concretize=True) - else: - # assign concretized bound for ReLU layer to save computational cost - # FIXME: Put ReLU after reshape will cause problem! - if (isinstance(node, BoundActivation) or isinstance(node, BoundTranspose)) and hasattr( - self._modules[node.input_name[0]], 'lower'): - node.lower = node.forward(self._modules[node.input_name[0]].lower) - node.upper = node.forward(self._modules[node.input_name[0]].upper) - elif isinstance(node, BoundReshape) and \ - hasattr(self._modules[node.input_name[0]], 'lower') and \ - hasattr(self._modules[node.input_name[1]], 'value'): - # Node for input value. - val_input = self._modules[node.input_name[0]] - # Node for input parameter (e.g., shape, permute) - arg_input = self._modules[node.input_name[1]] - node.lower = node.forward(val_input.lower, arg_input.value) - node.upper = node.forward(val_input.upper, arg_input.value) - else: - first_layer_flag = False - # This is the list of all intermediate layers where we need to refine. - if intermediate_constr is not None: - intermediate_beta_enabled_layers = [k for v in intermediate_constr.values() for k in v] - else: - intermediate_beta_enabled_layers = [] - # Here we avoid creating a big C matrix in the first linear layer. - # Disable this optimization when we have beta for intermediate layer bounds. - if type(node) == BoundLinear or type(node) == BoundConv and node.name not in intermediate_beta_enabled_layers: - for l_pre in node.input_name: - if type(self._modules[l_pre]) == BoundInput: - node.lower, node.upper = self._IBP_general(node) - first_layer_flag = True - break - if not first_layer_flag: - reduced_dim = False # Only partial neurons (unstable neurons) are bounded. 
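As the FIXME above notes, `C` changes what the computed bounds mean. A common pattern (a sketch, assuming a 10-class classifier with batch size 1 and true label 0) builds classification margins directly into `C`:

```python
import torch

# Hypothetical sketch: each row of C computes one margin
# logit_0 - logit_i; lb > 0 for all rows certifies the prediction.
num_classes, label = 10, 0
C = torch.zeros(1, num_classes - 1, num_classes)  # (batch, spec, out)
C[0, :, label] = 1.0
C[0, torch.arange(num_classes - 1), torch.arange(1, num_classes)] = -1.0
lb, ub = model.compute_bounds(x=(my_input,), C=C, method='backward')
```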
- unstable_idx = None - unstable_size = 99999 - dim = int(prod(node.output_shape[1:])) - sparse_intermediate_bounds = node.name in aux_reference_bounds and self.bound_opts.get('sparse_intermediate_bounds', False) and isinstance(self._modules[node.output_name[0]], BoundRelu) - sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False) - # FIXME: C matrix shape incorrect for BoundParams. - if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int( - os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0: - if sparse_intermediate_bounds: - # If we are doing bound refinement and reference bounds are given, we only refine unstable neurons. - # Also, if we are checking against LP solver we will refine all neurons and do not use this optimization. - # For each batch element, we find the unstable neurons. - unstable_idx, unstable_size = self.get_unstable_locations(node, aux_reference_bounds) - if unstable_size == 0: - # Do nothing, no bounds will be computed. - reduced_dim = True - unstable_idx = [] - elif unstable_size < 0.9 * dim and unstable_size > 0: - # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches. - newC = OneHotC([batch_size, unstable_size, *node.output_shape[1:]], self.device, unstable_idx, None) - reduced_dim = True - else: - unstable_idx = None - if not reduced_dim: - newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device) - elif (isinstance(node, BoundConv) or isinstance(node, - BoundBatchNormalization)) and node.mode == "patches": - if sparse_intermediate_bounds: - unstable_idx, unstable_size = self.get_unstable_conv_locations(node, aux_reference_bounds) - if unstable_size == 0: - # Do nothing, no bounds will be computed. - reduced_dim = True - unstable_idx = [] - # We sum over the channel direction, so need to multiply that. - elif sparse_conv_intermediate_bounds and unstable_size < 0.8 * dim: - # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches. - # The shape of patches is [unstable_size, batch, C, H, W]. - newC = Patches(patches=None, stride=1, padding=0, shape=[ - unstable_size, batch_size, node.output_shape[-3], 1, 1], - identity=1, unstable_idx=unstable_idx, output_shape=node.output_shape) - reduced_dim = True - else: - unstable_idx = None - # Here we create an Identity Patches object - if not reduced_dim: - newC = Patches(None, 1, 0, - [node.output_shape[-3], batch_size, node.output_shape[-2], node.output_shape[-1], - node.output_shape[-3], 1, 1], 1, output_shape=node.output_shape) - elif isinstance(node, BoundAdd) and node.mode == "patches": - if sparse_intermediate_bounds: - unstable_idx, unstable_size = self.get_unstable_conv_locations(node, aux_reference_bounds) - if unstable_size == 0: - # Do nothing, no bounds will be computed. - reduced_dim = True - unstable_idx = [] - elif sparse_conv_intermediate_bounds and unstable_size < 0.8 * dim: - num_channel = node.output_shape[-3] - # Identity patch size: (ouc_c, 1, 1, 1, out_c, 1, 1). - patches = (torch.eye(num_channel, device=self.device)).view(num_channel, 1, 1, 1, num_channel, 1, 1) - # Expand to (out_c, 1, unstable_size, out_c, 1, 1). - patches = patches.expand(-1, 1, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) - patches = patches[unstable_idx[0], :, unstable_idx[1], unstable_idx[2]] - # Expand with the batch dimension. Final shape (unstable_size, batch_size, out_c, 1, 1). 
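The removed code in this region builds reduced specification matrices (`OneHotC`, sparse `Patches`, or a dense one-hot tensor) so that only unstable neurons are re-bounded. A standalone sketch of the dense variant:

```python
import torch

# Standalone sketch of the dense reduced-C construction: bound only
# the unstable neurons of a 6-neuron layer instead of all of them.
batch_size, dim = 2, 6
unstable_idx = torch.tensor([1, 4])        # from the unstable mask
unstable_size = unstable_idx.numel()
newC = torch.zeros([1, unstable_size, dim])
newC[0, torch.arange(unstable_size), unstable_idx] = 1.0
newC = newC.expand(batch_size, -1, -1)     # shared across the batch
print(newC.shape)                          # torch.Size([2, 2, 6])
```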
- patches = patches.expand(-1, batch_size, -1, -1, -1) - newC = Patches(patches, 1, 0, patches.shape, unstable_idx=unstable_idx, output_shape=node.output_shape) - reduced_dim = True - else: - unstable_idx = None - if not reduced_dim: - num_channel = node.output_shape[-3] - # Identity patch size: (ouc_c, 1, 1, 1, out_c, 1, 1). - patches = (torch.eye(num_channel, device=self.device)).view(num_channel, 1, 1, 1, num_channel, 1, 1) - # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1). - patches = patches.expand(-1, batch_size, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) - newC = Patches(patches, 1, 0, patches.shape, output_shape=node.output_shape) - else: - if sparse_intermediate_bounds: - unstable_idx, unstable_size = self.get_unstable_locations(node, aux_reference_bounds) - if unstable_size == 0: - # Do nothing, no bounds will be computed. - reduced_dim = True - unstable_idx = [] - # Number of unstable neurons after merging. - elif unstable_size < 0.9 * dim: - # Create a C matrix. - newC = torch.zeros([1, unstable_size, dim], device=self.device) - # Fill the corresponding elements to 1.0 - newC[0, torch.arange(unstable_size), unstable_idx] = 1.0 - newC = newC.expand(batch_size, -1, -1).view(batch_size, unstable_size, *node.output_shape[1:]) - reduced_dim = True - # print(f'layer {node.name} total {dim} unstable {unstable_size} newC {newC.size()}') - else: - unstable_idx = None - if not reduced_dim: - if dim > 1000: - warnings.warn(f"Creating an identity matrix with size {dim}x{dim} for node {node}. This may indicate poor performance for bound computation. If you see this message on a small network please submit a bug report.", stacklevel=2) - newC = torch.eye(dim, device=self.device) \ - .unsqueeze(0).expand(batch_size, -1, -1) \ - .view(batch_size, dim, *node.output_shape[1:]) - if unstable_idx is None or unstable_size > 0: - # Compute backward bounds only when there are unstable neurons, or when we don't know which neurons are unstable. - self._backward_general(C=newC, node=node, root=root, return_A=False, intermediate_constr=intermediate_constr, unstable_idx=unstable_idx) - - if reduced_dim: - if unstable_size > 0: - # If we only calculated unstable neurons, we need to scatter the results back based on reference bounds. - if isinstance(unstable_idx, tuple): - new_lower = aux_reference_bounds[node.name][0].detach().clone() - new_upper = aux_reference_bounds[node.name][1].detach().clone() - # Conv layer with patches, the unstable_idx is a 3-element tuple for 3 indices (C, H,W) of unstable neurons. - new_lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = node.lower - new_upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = node.upper - else: - # Other layers. - new_lower = aux_reference_bounds[node.name][0].detach().clone().view(batch_size, -1) - new_upper = aux_reference_bounds[node.name][1].detach().clone().view(batch_size, -1) - new_lower[:, unstable_idx] = node.lower.view(batch_size, -1) - new_upper[:, unstable_idx] = node.upper.view(batch_size, -1) - # print(f'{node.name} {node} bound diff {(new_lower.view(-1) - aux_reference_bounds[node.name][0].view(-1)).abs().sum()} {(new_upper.view(-1) - aux_reference_bounds[node.name][1].view(-1)).abs().sum()}') - node.lower = new_lower.view(batch_size, *node.output_shape[1:]) - node.upper = new_upper.view(batch_size, *node.output_shape[1:]) - else: - # No unstable neurons. Skip the update. 
- node.lower = aux_reference_bounds[node.name][0].detach().clone() - node.upper = aux_reference_bounds[node.name][1].detach().clone() - # node.lower and node.upper (intermediate bounds) are computed in the above function. - # If we have bound references, we set them here to always obtain a better set of bounds. - if node.name in reference_bounds: - # Initially, the reference bound and the computed bound can be exactly the same when intermediate layer beta is 0. This will prevent gradients flow. So we need a small guard here. - if intermediate_constr is not None: - # Intermediate layer beta is used. - # Note that we cannot just take the reference bounds if they are better - this makes alphas have zero gradients. - node.lower = torch.max((0.9 * reference_bounds[node.name][0] + 0.1 * node.lower), node.lower) - node.upper = torch.min((0.9 * reference_bounds[node.name][1] + 0.1 * node.upper), node.upper) - # Additionally, if the reference bounds say a neuron is stable, we always keep it. (FIXME: this is for ReLU only). - lower_stable = reference_bounds[node.name][0] >= 0. - node.lower[lower_stable] = reference_bounds[node.name][0][lower_stable] - upper_stable = reference_bounds[node.name][1] <= 0. - node.upper[upper_stable] = reference_bounds[node.name][1][upper_stable] - else: - # MIP solved intermediate layer bounds. - # Set the intermediate layer bounds using reference bounds, always choosing the tighter one. - node.lower = torch.max(reference_bounds[node.name][0] - 1e-5, node.lower) - node.upper = torch.min(reference_bounds[node.name][1] + 1e-5, node.upper) - # Otherwise, we only use reference bounds to check which neurons are unstable. - - if method == 'backward': - # This is for the final output bound. No need to pass in intermediate layer beta constraints. - return self._backward_general(C=C, node=final, root=root, bound_lower=bound_lower, bound_upper=bound_upper, - return_A=return_A, needed_A_dict=needed_A_dict, average_A=average_A, A_dict=A_dict, - return_b=return_b, b_dict=b_dict, unstable_idx=alpha_idx, need_A_only=need_A_only) - elif method == 'forward': - return self._forward_general(C=C, node=final, root=root, dim_in=dim_in, concretize=True) - else: - raise NotImplementedError - - """ improvement on merging BoundLinear, BoundGatherElements and BoundSub - when loss fusion is used in training""" - def _IBP_loss_fusion(self, node, C): - # not using loss fusion - if not self.bound_opts.get('loss_fusion', False): - return None - - # Currently this function has issues in more complicated networks. 
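The removed `_IBP_loss_fusion` merges the final Linear, GatherElements and Sub nodes into a single interval step by rewriting the weights as margins against the true label. A standalone sketch of that weight rewrite on made-up shapes:

```python
import torch

# Standalone sketch of the loss-fusion rewrite: fold the "subtract the
# true-class logit" step directly into W and b, giving margins W_i - W_y.
batch, num_classes, features = 2, 3, 4
w = torch.randn(num_classes, features)
b = torch.randn(num_classes)
labels = torch.tensor([0, 2])

w = w.expand(batch, *w.shape)
w = w - torch.gather(
    w, dim=1,
    index=labels.view(-1, 1, 1).expand(-1, num_classes, features))
b = b.expand(batch, *b.shape)
b = b - torch.gather(b, dim=1, index=labels.view(-1, 1).expand(-1, num_classes))
print(w.shape, b.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3])
```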
- if self.bound_opts.get('no_ibp_loss_fusion', False): - return None - - if C is None and isinstance(node, BoundSub): - node_gather = self._modules[node.input_name[1]] - if isinstance(node_gather, BoundGatherElements) or isinstance(node_gather, BoundGatherAten): - node_linear = self._modules[node.input_name[0]] - node_start = self._modules[node_linear.input_name[0]] - if isinstance(node_linear, BoundLinear): - w = self._modules[node_linear.input_name[1]].param - b = self._modules[node_linear.input_name[2]].param - labels = self._modules[node_gather.input_name[1]] - if not hasattr(node_start, 'interval'): - self._IBP_general(node_start) - for inp in node_gather.input_name: - n = self._modules[inp] - if not hasattr(n, 'interval'): - self._IBP_general(n) - if torch.isclose(labels.lower, labels.upper, 1e-8).all(): - labels = labels.lower - batch_size = labels.shape[0] - w = w.expand(batch_size, *w.shape) - w = w - torch.gather(w, dim=1, - index=labels.unsqueeze(-1).repeat(1, w.shape[1], w.shape[2])) - b = b.expand(batch_size, *b.shape) - b = b - torch.gather(b, dim=1, - index=labels.repeat(1, b.shape[1])) - lower, upper = node_start.interval - lower, upper = lower.unsqueeze(1), upper.unsqueeze(1) - node.lower, node.upper = node_linear.interval_propagate( - (lower, upper), (w, w), (b.unsqueeze(1), b.unsqueeze(1))) - node.interval = node.lower, node.upper = node.lower.squeeze(1), node.upper.squeeze(1) - return node.interval - return None - - def _IBP_general(self, node=None, C=None): - if self.bound_opts.get('loss_fusion', False): - res = self._IBP_loss_fusion(node, C) - if res is not None: - return res - - if not node.perturbed and hasattr(node, 'forward_value'): - node.lower, node.upper = node.interval = (node.forward_value, node.forward_value) - if self.ibp_relative: - node.interval = Interval( - None, None, - nominal=node.forward_value, - lower_offset=torch.zeros_like(node.forward_value), - upper_offset=torch.zeros_like(node.forward_value)) - - if not hasattr(node, 'interval'): - for n in node.inputs: - if not hasattr(n, 'interval'): - self._IBP_general(n) - inp = [n_pre.interval for n_pre in node.inputs] - if C is not None and isinstance(node, BoundLinear) and not node.is_input_perturbed(1): - # merge the last BoundLinear node with the specification, available when - # weights of this layer are not perturbed - return node.interval_propagate(*inp, C=C) - else: - node.interval = node.interval_propagate(*inp) - - if self.ibp_relative: - node.lower, node.upper = node.interval.lower, node.interval.upper - else: - node.lower, node.upper = node.interval - if isinstance(node.lower, torch.Size): - node.lower = torch.tensor(node.lower) - node.interval = (node.lower, node.upper) - if isinstance(node.upper, torch.Size): - node.upper = torch.tensor(node.upper) - node.interval = (node.lower, node.upper) - - if C is not None: - return BoundLinear.interval_propagate(None, node.interval, C=C) - else: - return node.interval - - def _addA(self, A1, A2): - """ Add two A (each of them is either Tensor or Patches) """ - if type(A1) == torch.Tensor and type(A1) == torch.Tensor: - return A1 + A2 - elif type(A1) == Patches and type(A2) == Patches: - # Here we have to merge two patches, and if A1.stride != A2.stride, the patches will become a matrix, - # in this case, we will avoid using this mode - assert A1.stride == A2.stride, "A1.stride should be the same as A2.stride, otherwise, please use the matrix mode" - if A1.unstable_idx is not None or A2.unstable_idx is not None: - if A1.unstable_idx is not 
A2.unstable_idx: # Same tuple object. - raise ValueError('Please set bound option "sparse_conv_intermediate_bounds" to False to run this model.') - assert A1.output_shape == A2.output_shape - # change paddings to merge the two patches - if A1.padding != A2.padding: - if A1.padding > A2.padding: - A2 = A2._replace(patches=F.pad(A2.patches, ( - A1.padding - A2.padding, A1.padding - A2.padding, A1.padding - A2.padding, - A1.padding - A2.padding))) - else: - A1 = A1._replace(patches=F.pad(A1.patches, ( - A2.padding - A1.padding, A2.padding - A1.padding, A2.padding - A1.padding, - A2.padding - A1.padding))) - sum_ret = A1.patches + A2.patches - return Patches(sum_ret, A2.stride, max(A1.padding, A2.padding), sum_ret.shape, unstable_idx=A1.unstable_idx, output_shape=A1.output_shape) - else: - if type(A1) == Patches: - pieces = A1.patches - stride = A1.stride - padding = A1.padding - patch_output_shape = A1.output_shape - patch_unstable_idx = A1.unstable_idx - # Patches has shape (out_c, batch, out_h, out_w, in_c, h, w). - input_shape = A2.shape[3:] - matrix = A2 - if type(A2) == Patches: - pieces = A2.patches - stride = A2.stride - padding = A2.padding - patch_output_shape = A2.output_shape - patch_unstable_idx = A2.unstable_idx - input_shape = A1.shape[3:] - matrix = A1 - A1_matrix = patches_to_matrix( - pieces, input_shape, stride, padding, output_shape=patch_output_shape, unstable_idx=patch_unstable_idx) - return A1_matrix.transpose(0,1) + matrix - - def _backward_general(self, C=None, node=None, root=None, bound_lower=True, bound_upper=True, - return_A=False, needed_A_dict=None, average_A=False, A_dict=None, return_b=False, b_dict=None, - intermediate_constr=None, unstable_idx=None, need_A_only=False): - logger.debug('Backward from ({})[{}]'.format(node, node.name)) - _print_time = False - - degree_out = {} - for l in self._modules.values(): - l.bounded = True - l.lA = l.uA = None - degree_out[l.name] = 0 - queue = deque([node]) - all_nodes_before = [] - while len(queue) > 0: - l = queue.popleft() - self.backward_from[l.name].append(node) - for l_pre in l.input_name: - all_nodes_before.append(l_pre) - degree_out[l_pre] += 1 # calculate the out degree - if self._modules[l_pre].bounded: - self._modules[l_pre].bounded = False - queue.append(self._modules[l_pre]) - node.bounded = True - if isinstance(C, Patches): - if C.unstable_idx is None: - # Patches have size (out_c, batch, out_h, out_w, c, h, w). - out_c, batch_size, out_h, out_w = C.shape[:4] - output_dim = out_c * out_h * out_w - else: - # Patches have size (unstable_size, batch, c, h, w). - output_dim, batch_size = C.shape[:2] - else: - batch_size, output_dim = C.shape[:2] - - # The C matrix specified by the user has shape (batch, spec) but internally we have (spec, batch) format. - if not isinstance(C, (eyeC, Patches, OneHotC)): - C = C.transpose(0, 1) - elif isinstance(C, eyeC): - C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:])) - elif isinstance(C, OneHotC): - C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:]), index=C.index.transpose(0,-1), coeffs=None if C.coeffs is None else C.coeffs.transpose(0,-1)) - - node.lA = C if bound_lower else None - node.uA = C if bound_upper else None - lb = ub = torch.tensor(0., device=self.device) - - - # Save intermediate layer A matrices when required. 
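`_backward_general` schedules nodes by out-degree: a node is propagated through only after every successor has contributed its A matrix. A minimal standalone sketch of that scheduling on a toy graph:

```python
from collections import deque

# Toy graph: each node maps to its predecessor (input) nodes.
graph = {'out': ['relu'], 'relu': ['linear'], 'linear': ['x', 'w']}
degree_out = {n: 0 for n in ['out', 'relu', 'linear', 'x', 'w']}
for node, preds in graph.items():
    for p in preds:
        degree_out[p] += 1  # count how many successors feed back into p

queue = deque(['out'])
while queue:
    node = queue.popleft()
    print('propagate through', node)
    for p in graph.get(node, []):
        degree_out[p] -= 1
        if degree_out[p] == 0:  # all successors done; p is ready
            queue.append(p)
```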
- A_record = {} - - queue = deque([node]) - while len(queue) > 0: - l = queue.popleft() # backward from l - l.bounded = True - - if return_b: - b_dict[l.name] = { 'lower_b': lb, 'upper_b': ub } - - if l.name in self.root_name or l == root: continue - - for l_pre in l.input_name: # if all the succeeds are done, then we can turn to this node in the next iteration. - _l = self._modules[l_pre] - degree_out[l_pre] -= 1 - if degree_out[l_pre] == 0: - queue.append(_l) - - # Initially, l.lA or l.uA will be set to C for this node. - if l.lA is not None or l.uA is not None: - # Propagate lA and uA to a preceding node - def add_bound(node, lA, uA): - if lA is not None: - if node.lA is None: - # First A added to this node. - node.zero_lA_mtx = l.zero_backward_coeffs_l - node.lA = lA - else: - node.zero_lA_mtx = node.zero_lA_mtx and l.zero_backward_coeffs_l - node.lA = self._addA(node.lA, lA) - if uA is not None: - if node.uA is None: - # First A added to this node. - node.zero_uA_mtx = l.zero_backward_coeffs_u - node.uA = uA - else: - node.zero_uA_mtx = node.zero_uA_mtx and l.zero_backward_coeffs_u - node.uA = self._addA(node.uA, uA) - - if _print_time: - start_time = time.time() - - # FIXME make fixed nodes have fixed `forward_value` that is never cleaned out - if not l.perturbed and hasattr(l, 'forward_value'): - lb = lb + l.get_bias(l.lA, l.forward_value) # FIXME (09/16): shape for the bias of BoundConstant. - ub = ub + l.get_bias(l.uA, l.forward_value) - continue - - if l.zero_uA_mtx and l.zero_lA_mtx: - # A matrices are all zero, no need to propagate. - continue - - if isinstance(l, BoundRelu): - A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs, start_node=node, unstable_idx=unstable_idx, - beta_for_intermediate_layers=intermediate_constr is not None) # TODO: unify this interface. - elif isinstance(l, BoundOptimizableActivation): - A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs, - start_shape=(prod(node.output_shape[1:]) if node.name != self.final_name - else C.shape[0]), start_node=node) - else: - A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs) - - if _print_time: - time_elapsed = time.time() - start_time - if time_elapsed > 1e-3: - print(l, time_elapsed) - if lb.ndim > 0 and type(lower_b) == torch.Tensor and self.conv_mode == 'patches': - # When we use patches mode, it's possible that we need to add two bias - # one is from the Tensor mode and one is from the patches mode - # And we need to detect this case and reshape the bias - assert lower_b.ndim == 4 or lower_b.ndim == 2 - assert lb.ndim == 4 or lb.ndim == 2 - if lower_b.ndim < lb.ndim: - lb = lb.transpose(0,1).reshape(lb.size(1), lb.size(0), -1) - lb = lb.expand(lb.size(0), lb.size(1), lower_b.size(0)//lb.size(1)) - lb = lb.reshape(lb.size(0), -1).t() - ub = ub.transpose(0,1).reshape(ub.size(1), ub.size(0), -1) - ub = ub.expand(ub.size(0), ub.size(1), upper_b.size(0)//ub.size(1)) - ub = ub.reshape(ub.size(0), -1).t() - elif lower_b.ndim > lb.ndim: - lower_b = lower_b.transpose(0,1).reshape(lower_b.size(1), -1).t() - upper_b = upper_b.transpose(0,1).reshape(upper_b.size(1), -1).t() - lb = lb + lower_b - ub = ub + upper_b - if return_A and needed_A_dict and node.name in needed_A_dict: - # FIXME ??? - # if isinstance(self._modules[l.output_name[0]], BoundRelu): - # We save the A matrices after propagating through layer l if a ReLU follows l. - # Here we saved the A *after* propagating backwards through this layer. 
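Once the A matrices reach the perturbed root nodes (below), `perturbation.concretize` turns them into scalar bounds; for an L-inf ball this is Hölder's inequality with the dual L-1 norm. A standalone sketch of the lower-bound case:

```python
import torch

# Standalone sketch: concretize lA @ x + lbias over an L_inf ball of
# radius eps around x0 (dual norm of L_inf is L_1).
batch, spec, dim = 1, 2, 3
lA = torch.randn(batch, spec, dim)
lbias = torch.randn(batch, spec)
x0 = torch.randn(batch, dim)
eps = 0.1

center = lA.matmul(x0.unsqueeze(-1)).squeeze(-1)
deviation = lA.abs().sum(dim=-1) * eps
lb = center - deviation + lbias  # sound lower bound of lA @ x + lbias
```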
- # Note that we return the accumulated bias terms, to maintain linear relationship of this node (TODO: does not support ResNet). - if l.name in needed_A_dict[node.name]: - A_record.update({l.name: { - "lA": A[0][0].transpose(0, 1).detach() if A[0][0] is not None else None, - "uA": A[0][1].transpose(0, 1) .detach()if A[0][1] is not None else None, - "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None, - # When not used, lb or ub is tensor(0). - "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None, - }}) - A_dict.update({node.name: A_record}) - if need_A_only and set(needed_A_dict[node.name]) == set(A_record.keys()): - # We have collected all A matrices we need. We can return now! - A_dict.update({node.name: A_record}) - # Do not concretize to save time. We just need the A matrices. - # return A matrix as a dict: {node.name: [A_lower, A_upper]} - return None, None, A_dict - - - for i, l_pre in enumerate(l.input_name): - _l = self._modules[l_pre] - add_bound(_l, lA=A[i][0], uA=A[i][1]) - - if lb.ndim >= 2: - lb = lb.transpose(0, 1) - if ub.ndim >= 2: - ub = ub.transpose(0, 1) - - if return_A and needed_A_dict and node.name in needed_A_dict: - root_A_record = {} - for i in range(len(root)): - if root[i].lA is None and root[i].uA is None: continue - if root[i].name in needed_A_dict[node.name]: - root_A_record.update({root[i].name: { - "lA": root[i].lA.transpose(0, 1).detach() if root[i].lA is not None else None, - "uA": root[i].uA.transpose(0, 1).detach() if root[i].uA is not None else None, - }}) - root_A_record.update(A_record) # merge to existing A_record - A_dict.update({node.name: root_A_record}) - - for i in range(len(root)): - if root[i].lA is None and root[i].uA is None: continue - if average_A and isinstance(root[i], BoundParams): - lA = root[i].lA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].lA.shape) if bound_lower else None - uA = root[i].uA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].uA.shape) if bound_upper else None - else: - lA, uA = root[i].lA, root[i].uA - - if not isinstance(root[i].lA, eyeC) and not isinstance(root[i].lA, Patches): - lA = root[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None - if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].uA, Patches): - uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - if isinstance(root[i], BoundParams): - # add batch_size dim for weights node - lb = lb + root[i].perturbation.concretize( - root[i].center.unsqueeze(0), lA, - sign=-1, aux=root[i].aux) if bound_lower else None - ub = ub + root[i].perturbation.concretize( - root[i].center.unsqueeze(0), uA, - sign=+1, aux=root[i].aux) if bound_upper else None - else: - lb = lb + root[i].perturbation.concretize(root[i].center, lA, sign=-1, - aux=root[i].aux) if bound_lower else None - ub = ub + root[i].perturbation.concretize(root[i].center, uA, sign=+1, - aux=root[i].aux) if bound_upper else None - else: - if i < self.num_global_inputs: - # Input node so there is a batch dimension - fv = root[i].forward_value.view(batch_size, -1, 1) - batch_size_ = batch_size - else: - # Parameter node so there is no batch dimension - fv = root[i].forward_value.view(-1, 1) - batch_size_ = 1 - if isinstance(lA, eyeC): - lb = lb + fv.view(batch_size_, -1) if bound_lower else None - else: - lb = lb + lA.matmul(fv).squeeze(-1) if bound_lower else None - if isinstance(uA, eyeC): - ub = ub + 
fv.view(batch_size_, -1) if bound_upper else None - else: - ub = ub + uA.matmul(fv).squeeze(-1) if bound_upper else None - - if isinstance(C, Patches) and C.unstable_idx is not None: - # Sparse patches; the output shape is (unstable_size, ). - output_shape = [C.shape[0]] - elif prod(node.output_shape[1:]) != output_dim and not isinstance(C, Patches): - # For the output node, the shape of the bound follows C - # instead of the original output shape - # TODO Maybe don't set node.lower and node.upper in this case? - # Currently some codes still depend on node.lower and node.upper - output_shape = [-1] - else: - # Generally, the shape of the bounds match the output shape of the node - output_shape = node.output_shape[1:] - lb = node.lower = lb.view(batch_size, *output_shape) if bound_lower else None - ub = node.upper = ub.view(batch_size, *output_shape) if bound_upper else None - - if return_A: return lb, ub, A_dict - - return lb, ub - - def _forward_general(self, C=None, node=None, root=None, dim_in=None, concretize=False): - if hasattr(node, 'lower'): - return node.lower, node.upper - - if not node.from_input: - w, b = None, node.value - node.linear = LinearBound(w, b, w, b, b, b) - node.lower = node.upper = b - node.interval = (node.lower, node.upper) - return node.interval - - if not hasattr(node, 'linear'): - for l_pre in node.input_name: - l = self._modules[l_pre] - if not hasattr(l, 'linear'): - self._forward_general(node=l, root=root, dim_in=dim_in) - - inp = [self._modules[l_pre].linear for l_pre in node.input_name] - - if C is not None and isinstance(node, BoundLinear) and not node.is_input_perturbed(1): - node.linear = node.bound_forward(dim_in, *inp, C=C) - C_merged = True - else: - node.linear = node.bound_forward(dim_in, *inp) - C_merged = False - - lw, uw = node.linear.lw, node.linear.uw - lower, upper = node.linear.lb, node.linear.ub - - if C is not None and not C_merged: - # FIXME use bound_forward of BoundLinear - C_pos, C_neg = C.clamp(min=0), C.clamp(max=0) - _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2)) - _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2)) - lw, uw = _lw, _uw - _lower = torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) + \ - torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2)) - _upper = torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2)) + \ - torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2)) - lower, upper = _lower.squeeze(1), _upper.squeeze(1) - else: - lw, uw = node.linear.lw, node.linear.uw - lower, upper = node.linear.lb, node.linear.ub - - if concretize: - if node.linear.lw is not None: - prev_dim_in = 0 - batch_size = lw.shape[0] - assert (lw.ndim > 1) - lA = lw.reshape(batch_size, dim_in, -1).transpose(1, 2) - uA = uw.reshape(batch_size, dim_in, -1).transpose(1, 2) - for i in range(len(root)): - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - _lA = lA[:, :, prev_dim_in : (prev_dim_in + root[i].dim)] - _uA = uA[:, :, prev_dim_in : (prev_dim_in + root[i].dim)] - lower = lower + root[i].perturbation.concretize( - root[i].center, _lA, sign=-1, aux=root[i].aux).view(lower.shape) - upper = upper + root[i].perturbation.concretize( - root[i].center, _uA, sign=+1, aux=root[i].aux).view(upper.shape) - prev_dim_in += root[i].dim - if C is None: - node.linear = node.linear._replace(lower=lower, upper=upper) - if C is None: - node.lower, node.upper = lower, upper - if not Benchmarking and torch.isnan(lower).any(): - 
import pdb - pdb.set_trace() - return lower, upper - - def _init_forward(self, root, dim_in): - if dim_in == 0: - raise ValueError("At least one node should have a specified perturbation") - prev_dim_in = 0 - # Assumption: root[0] is the input node which implies batch_size - batch_size = root[0].value.shape[0] - for i in range(len(root)): - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - shape = root[i].linear.lw.shape - device = root[i].linear.lw.device - dtype = root[i].linear.lw.dtype - root[i].linear = root[i].linear._replace( - lw=torch.cat([ - torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device, dtype=dtype), - root[i].linear.lw, - torch.zeros(shape[0], dim_in - shape[1], *shape[2:], device=device, dtype=dtype) - ], dim=1), - uw=torch.cat([ - torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device, dtype=dtype), - root[i].linear.uw, - torch.zeros(shape[0], dim_in - shape[1] - prev_dim_in, *shape[2:], device=device, dtype=dtype) - ], dim=1) - ) - if i >= self.num_global_inputs: - root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat( - *([batch_size] + [1] * self.forward_value.ndim)) - prev_dim_in += shape[1] - else: - b = fv = root[i].forward_value - shape = fv.shape - if root[i].from_input: - w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device) - else: - w = None - root[i].linear = LinearBound(w, b, w, b, b, b) - root[i].lower = root[i].upper = b - root[i].interval = (root[i].lower, root[i].upper) - - """Add perturbation to an intermediate node and it is treated as an independent - node in bound computation.""" - def add_intermediate_perturbation(self, node, perturbation): + """Add perturbation to an intermediate node and it is treated as an + independent node in bound computation.""" node.perturbation = perturbation node.perturbed = True # NOTE This change is currently inreversible if not node.name in self.root_name: self.root_name.append(node.name) - - -class BoundDataParallel(DataParallel): - # https://github.com/huanzhang12/CROWN-IBP/blob/master/bound_layers.py - # This is a customized DataParallel class for our project - def __init__(self, *inputs, **kwargs): - super(BoundDataParallel, self).__init__(*inputs, **kwargs) - self._replicas = None - - # Overide the forward method - def forward(self, *inputs, **kwargs): - disable_multi_gpu = False # forward by single GPU - no_replicas = False # forward by multi GPUs but without replicate - if "disable_multi_gpu" in kwargs: - disable_multi_gpu = kwargs["disable_multi_gpu"] - kwargs.pop("disable_multi_gpu") - - if "no_replicas" in kwargs: - no_replicas = kwargs["no_replicas"] - kwargs.pop("no_replicas") - - if not self.device_ids or disable_multi_gpu: - if kwargs.pop("get_property", False): - return self.get_property(self, *inputs, **kwargs) - return self.module(*inputs, **kwargs) - - if kwargs.pop("get_property", False): - if self._replicas is None: - assert 0, 'please call IBP/CROWN before get_property' - if len(self.device_ids) == 1: - return self.get_property(self.module, **kwargs) - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - kwargs = list(kwargs) - for i in range(len(kwargs)): - kwargs[i]['model'] = self._replicas[i] - outputs = self.parallel_apply([self.get_property] * len(kwargs), inputs, kwargs) - return self.gather(outputs, self.output_device) - - # Only replicate during forward/IBP propagation. Not during interval bounds - # and CROWN-IBP bounds, since weights have not been updated. This saves 2/3 - # of communication cost. 
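A hypothetical usage of the `add_intermediate_perturbation` API above; the node name '/relu1' is made up (inspect the traced graph for real names), and `model`/`my_input` are assumed as in the quick start:

```python
from auto_LiRPA.perturbations import PerturbationLpNorm

# Hypothetical sketch: attach a perturbation to an intermediate node,
# which is then treated as an extra root node in bound computation.
ptb = PerturbationLpNorm(norm=float('inf'), eps=0.01)
node = model['/relu1']  # made-up node name
model.add_intermediate_perturbation(node, ptb)
lb, ub = model.compute_bounds(x=(my_input,), method='backward')
```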
- if not no_replicas: - if self._replicas is None: # first time - self._replicas = self.replicate(self.module, self.device_ids) - elif kwargs.get("method_opt", "forward") == "forward": - self._replicas = self.replicate(self.module, self.device_ids) - elif kwargs.get("x") is not None and kwargs.get("IBP") is True: # - self._replicas = self.replicate(self.module, self.device_ids) - # Update the input nodes to the ones within each replica respectively - for bounded_module in self._replicas: - for node in bounded_module._modules.values(): - node.inputs = [bounded_module._modules[name] for name in node.input_name] - - for t in chain(self.module.parameters(), self.module.buffers()): - if t.device != self.src_device_obj: - raise RuntimeError("module must have its parameters and buffers " - "on device {} (device_ids[0]) but found one of " - "them on device: {}".format(self.src_device_obj, t.device)) - - # TODO: can be done in parallel, only support same ptb for all inputs per forward/IBP propagation - if len(inputs) > 0 and hasattr(inputs[0], 'ptb') and inputs[0].ptb is not None: - # compute bounds without x - # inputs_scatter is a normal tensor, we need to assign ptb to it if inputs is a BoundedTensor - inputs_scatter, kwargs = self.scatter((inputs, inputs[0].ptb.x_L, inputs[0].ptb.x_U), kwargs, - self.device_ids) - # inputs_scatter = inputs_scatter[0] - bounded_inputs = [] - for input_s in inputs_scatter: # GPU numbers - ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2]) - # bounded_inputs.append(tuple([(BoundedTensor(input_s[0][0], ptb))])) - input_s = list(input_s[0]) - input_s[0] = BoundedTensor(input_s[0], ptb) - input_s = tuple(input_s) - bounded_inputs.append(input_s) - - # bounded_inputs = tuple(bounded_inputs) - elif kwargs.get("x") is not None and hasattr(kwargs.get("x")[0], 'ptb') and kwargs.get("x")[0].ptb is not None: - # compute bounds with x - # kwargs['x'] is a normal tensor, we need to assign ptb to it - x = kwargs.get("x")[0] - bounded_inputs = [] - inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids) - for input_s, kw_s in zip(inputs_scatter, kwargs): # GPU numbers - ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2]) - kw_s['x'] = list(kw_s['x']) - kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb) - kw_s['x'] = (kw_s['x']) - bounded_inputs.append(tuple(input_s[0], )) - else: - # normal forward - inputs_scatter, kwargs = self.scatter(inputs, kwargs, self.device_ids) - bounded_inputs = inputs_scatter - - if len(self.device_ids) == 1: - return self.module(*bounded_inputs[0], **kwargs[0]) - outputs = self.parallel_apply(self._replicas[:len(bounded_inputs)], bounded_inputs, kwargs) - return self.gather(outputs, self.output_device) - - @staticmethod - def get_property(model, node_class=None, att_name=None, node_name=None): - if node_name: - # Find node by name - # FIXME If we use `model.named_modules()`, the nodes have the - # `BoundedModule` type rather than bound nodes. - for node in model._modules.values(): - if node.name == node_name: - return getattr(node, att_name) - else: - # Find node by class - for _, node in model.named_modules(): - # Find the Exp neuron in computational graph - if isinstance(node, node_class): - return getattr(node, att_name) - - def state_dict(self, destination=None, prefix='', keep_vars=False): - # add 'module.' 
here before each keys in self.module.state_dict() if needed - return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - - def _named_members(self, get_members_fn, prefix='', recurse=True): - return self.module._named_members(get_members_fn, prefix, recurse) - + from .interval_bound import ( + IBP_general, _IBP_loss_fusion, check_IBP_intermediate, + check_IBP_first_linear) + from .forward_bound import ( + forward_general, forward_general_dynamic, init_forward) + from .backward_bound import ( + backward_general, get_sparse_C, check_optimized_variable_sparsity, + restore_sparse_bounds, get_alpha_crown_start_nodes, + get_unstable_locations, batched_backward) + from .optimized_bounds import get_optimized_bounds, init_slope + from .beta_crown import ( + beta_bias, save_best_intermediate_betas, + print_optimized_beta) + + + from .solver_module import build_solver_module, _build_solver_input, _build_solver_general diff --git a/auto_LiRPA/bound_multi_gpu.py b/auto_LiRPA/bound_multi_gpu.py new file mode 100644 index 0000000..057ef42 --- /dev/null +++ b/auto_LiRPA/bound_multi_gpu.py @@ -0,0 +1,127 @@ +from torch.nn import DataParallel +from .perturbations import * +from .bounded_tensor import BoundedTensor +from itertools import chain + +class BoundDataParallel(DataParallel): + # https://github.com/huanzhang12/CROWN-IBP/blob/master/bound_layers.py + # This is a customized DataParallel class for our project + def __init__(self, *inputs, **kwargs): + super(BoundDataParallel, self).__init__(*inputs, **kwargs) + self._replicas = None + + # Override the forward method + def forward(self, *inputs, **kwargs): + disable_multi_gpu = False # forward by single GPU + no_replicas = False # forward by multiple GPUs but without replication + if "disable_multi_gpu" in kwargs: + disable_multi_gpu = kwargs["disable_multi_gpu"] + kwargs.pop("disable_multi_gpu") + + if "no_replicas" in kwargs: + no_replicas = kwargs["no_replicas"] + kwargs.pop("no_replicas") + + if not self.device_ids or disable_multi_gpu: + if kwargs.pop("get_property", False): + return self.get_property(self, *inputs, **kwargs) + return self.module(*inputs, **kwargs) + + if kwargs.pop("get_property", False): + if self._replicas is None: + assert 0, 'please call IBP/CROWN before get_property' + if len(self.device_ids) == 1: + return self.get_property(self.module, **kwargs) + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + kwargs = list(kwargs) + for i in range(len(kwargs)): + kwargs[i]['model'] = self._replicas[i] + outputs = self.parallel_apply([self.get_property] * len(kwargs), inputs, kwargs) + return self.gather(outputs, self.output_device) + + # Only replicate during forward/IBP propagation. Not during interval bounds + # and CROWN-IBP bounds, since weights have not been updated. This saves 2/3 + # of communication cost.
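As a usage sketch of this class (assuming `bounded_model` is a `BoundedModule` and `my_input` a `BoundedTensor` as in the quick start), `BoundDataParallel` is a drop-in replacement for `torch.nn.DataParallel` that understands the extra keyword arguments handled in `forward` above:

```python
# Hypothetical sketch: wrap a BoundedModule for multi-GPU computation.
model = BoundDataParallel(bounded_model)
pred = model(my_input)                           # replicated forward
pred_single = model(my_input, disable_multi_gpu=True)  # single-GPU path
```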
+ if not no_replicas: + if self._replicas is None: # first time + self._replicas = self.replicate(self.module, self.device_ids) + elif kwargs.get("method_opt", "forward") == "forward": + self._replicas = self.replicate(self.module, self.device_ids) + elif kwargs.get("x") is not None and kwargs.get("IBP") is True: # + self._replicas = self.replicate(self.module, self.device_ids) + # Update the input nodes to the ones within each replica respectively + for bounded_module in self._replicas: + for node in bounded_module._modules.values(): + node.inputs = [bounded_module[name] for name in node.input_name] + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError("module must have its parameters and buffers " + "on device {} (device_ids[0]) but found one of " + "them on device: {}".format(self.src_device_obj, t.device)) + + # TODO: can be done in parallel, only support same ptb for all inputs per forward/IBP propagation + if len(inputs) > 0 and hasattr(inputs[0], 'ptb') and inputs[0].ptb is not None: + # compute bounds without x + # inputs_scatter is a normal tensor, we need to assign ptb to it if inputs is a BoundedTensor + inputs_scatter, kwargs = self.scatter((inputs, inputs[0].ptb.x_L, inputs[0].ptb.x_U), kwargs, + self.device_ids) + # inputs_scatter = inputs_scatter[0] + bounded_inputs = [] + for input_s in inputs_scatter: # GPU numbers + # FIXME other perturbations are not supported yet + assert isinstance(inputs[0].ptb, PerturbationLpNorm) + ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2]) + input_s = list(input_s[0]) + input_s[0] = BoundedTensor(input_s[0], ptb) + input_s = tuple(input_s) + bounded_inputs.append(input_s) + + # bounded_inputs = tuple(bounded_inputs) + elif kwargs.get("x") is not None and hasattr(kwargs.get("x")[0], 'ptb') and kwargs.get("x")[0].ptb is not None: + # compute bounds with x + # kwargs['x'] is a normal tensor, we need to assign ptb to it + x = kwargs.get("x")[0] + bounded_inputs = [] + inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids) + for input_s, kw_s in zip(inputs_scatter, kwargs): # GPU numbers + # FIXME other perturbations are not supported yet + assert isinstance(x.ptb, PerturbationLpNorm) + ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2]) + kw_s['x'] = list(kw_s['x']) + kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb) + kw_s['x'] = (kw_s['x']) + bounded_inputs.append(tuple(input_s[0], )) + else: + # normal forward + inputs_scatter, kwargs = self.scatter(inputs, kwargs, self.device_ids) + bounded_inputs = inputs_scatter + + if len(self.device_ids) == 1: + return self.module(*bounded_inputs[0], **kwargs[0]) + outputs = self.parallel_apply(self._replicas[:len(bounded_inputs)], bounded_inputs, kwargs) + return self.gather(outputs, self.output_device) + + @staticmethod + def get_property(model, node_class=None, att_name=None, node_name=None): + if node_name: + # Find node by name + # FIXME If we use `model.named_modules()`, the nodes have the + # `BoundedModule` type rather than bound nodes. 
+ for node in model._modules.values(): + if node.name == node_name: + return getattr(node, att_name) + else: + # Find node by class + for _, node in model.named_modules(): + # Find the Exp neuron in computational graph + if isinstance(node, node_class): + return getattr(node, att_name) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + # add 'module.' here before each keys in self.module.state_dict() if needed + return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + def _named_members(self, get_members_fn, prefix='', recurse=True): + return self.module._named_members(get_members_fn, prefix, recurse) + diff --git a/auto_LiRPA/bound_op_map.py b/auto_LiRPA/bound_op_map.py index 087dab3..c33d52e 100644 --- a/auto_LiRPA/bound_op_map.py +++ b/auto_LiRPA/bound_op_map.py @@ -6,19 +6,7 @@ } def register_custom_op(op_name: str, bound_obj: Bound) -> None: - """ Register a custom operator. - - Args: - op_name (str): Name of the custom operator - - bound_obj (Bound): The corresponding Bound class for the operator. - """ bound_op_map[op_name] = bound_obj def unregister_custom_op(op_name: str) -> None: - """ Unregister a custom operator. - - Args: - op_name (str): Name of the custom operator - """ - bound_op_map.pop(op_name) \ No newline at end of file + bound_op_map.pop(op_name) diff --git a/auto_LiRPA/bounded_tensor.py b/auto_LiRPA/bounded_tensor.py index ce922d9..cb3302f 100644 --- a/auto_LiRPA/bounded_tensor.py +++ b/auto_LiRPA/bounded_tensor.py @@ -1,3 +1,4 @@ +import copy import torch import torch.nn as nn from torch import Tensor as Tensor @@ -26,7 +27,7 @@ def __repr__(self): return ''.format(super().__repr__()) def clone(self, *args, **kwargs): - tensor = BoundedTensor(super().clone(*args, **kwargs), self.ptb) + tensor = BoundedTensor(super().clone(*args, **kwargs), copy.deepcopy(self.ptb)) return tensor def _func(self, func, *args, **kwargs): @@ -38,6 +39,13 @@ def _func(self, func, *args, **kwargs): # Copy to other devices with perturbation def to(self, *args, **kwargs): + # FIXME add a general "to" function in perturbation class, not here. 
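As the FIXME above notes, device movement arguably belongs on the perturbation object itself. A hypothetical version of that refactor, assuming a simple perturbation holding `eps`, `x_L`, `x_U` (`SimplePerturbation` is not a class in this library):

```python
import torch

class SimplePerturbation:
    def __init__(self, eps, x_L=None, x_U=None):
        self.eps, self.x_L, self.x_U = eps, x_L, x_U

    def to(self, *args, **kwargs):
        # Move any tensor-valued fields; a scalar eps is left untouched.
        for name in ('x_L', 'x_U', 'eps'):
            val = getattr(self, name)
            if isinstance(val, torch.Tensor):
                setattr(self, name, val.to(*args, **kwargs))
        return self
```

`BoundedTensor.to` could then simply call `self.ptb.to(*args, **kwargs)` before moving the tensor data itself.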
+        if hasattr(self.ptb, 'x_L') and isinstance(self.ptb.x_L, Tensor):
+            self.ptb.x_L = self.ptb.x_L.to(*args, **kwargs)
+        if hasattr(self.ptb, 'x_U') and isinstance(self.ptb.x_U, Tensor):
+            self.ptb.x_U = self.ptb.x_U.to(*args, **kwargs)
+        if hasattr(self.ptb, 'eps') and isinstance(self.ptb.eps, Tensor):
+            self.ptb.eps = self.ptb.eps.to(*args, **kwargs)
         return self._func(super().to, *args, **kwargs)
 
     @classmethod
diff --git a/auto_LiRPA/cuda/cuda_kernels.cu b/auto_LiRPA/cuda/cuda_kernels.cu
new file mode 100644
index 0000000..a67a30b
--- /dev/null
+++ b/auto_LiRPA/cuda/cuda_kernels.cu
@@ -0,0 +1,40 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+__global__ void cuda_double2float_rd_kernel(const double* __restrict__ inputs,
+    float* __restrict__ outputs, const size_t tensor_size) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < tensor_size) {
+    outputs[idx] = __double2float_rd(inputs[idx]);
+  }
+}
+
+__global__ void cuda_double2float_ru_kernel(const double* __restrict__ inputs,
+    float* __restrict__ outputs, const size_t tensor_size) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < tensor_size) {
+    outputs[idx] = __double2float_ru(inputs[idx]);
+  }
+}
+
+torch::Tensor cuda_double2float_forward(torch::Tensor input,
+    const std::string direction) {
+  auto total_elem = input.numel();
+  auto output = torch::empty_like(input, torch::ScalarType::Float);
+
+  const int threads = 1024;
+  const int blocks = (total_elem + threads - 1) / threads;
+
+  if (direction == "down") {
+    cuda_double2float_rd_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);
+  }
+  else {
+    cuda_double2float_ru_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);
+  }
+  return output;
+}
+
diff --git a/auto_LiRPA/cuda/cuda_utils.cpp b/auto_LiRPA/cuda/cuda_utils.cpp
new file mode 100644
index 0000000..867c910
--- /dev/null
+++ b/auto_LiRPA/cuda/cuda_utils.cpp
@@ -0,0 +1,25 @@
+#include <torch/extension.h>
+
+#include <string>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+
+torch::Tensor cuda_double2float_forward(
+    torch::Tensor input, const std::string direction);
+
+torch::Tensor double2float_forward(
+    torch::Tensor input, const std::string direction) {
+  TORCH_CHECK((direction == "down") || (direction == "up"), "Unsupported direction, must be down or up.");
+  TORCH_CHECK(input.type().scalarType() == torch::ScalarType::Double, "This function only supports DoubleTensor as inputs.");
+  CHECK_CUDA(input);
+  return cuda_double2float_forward(input, direction);
+}
+
+/*
+ * Usage: double2float(tensor, direction)
+ * "tensor" must be a DoubleTensor on GPU.
+ * "direction" is a string, can be "up" (round up) or "down" (round down).
+ */
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("double2float", &double2float_forward, "Convert double to float with rounding direction control (direction = 'up' or 'down').");
+}
diff --git a/auto_LiRPA/cuda_utils.py b/auto_LiRPA/cuda_utils.py
new file mode 100644
index 0000000..912ac6c
--- /dev/null
+++ b/auto_LiRPA/cuda_utils.py
@@ -0,0 +1,126 @@
+import os
+import sys
+import time
+import torch
+from torch.utils.cpp_extension import load, BuildExtension, CUDAExtension
+from setuptools import setup
+
+class DummyCudaClass:
+    """A dummy class with an error message when a CUDA function is called."""
+    def __getattr__(self, attr):
+        if attr == "double2float":
+            # When the CUDA module is not built successfully, use a workaround.
+            def _f(x, d):
+                print('WARNING: Missing CUDA kernels.
Please enable CUDA build by setting environment variable AUTOLIRPA_ENABLE_CUDA_BUILD=1 for the correct behavior!') + return x.float() + return _f + def _f(*args, **kwargs): + raise RuntimeError(f"method {attr} not available because CUDA module was not built.") + return _f + +if __name__ == "__main__" and len(sys.argv) > 1: + # Build and install native CUDA modules that can be directly imported later + print('Building and installing native CUDA modules...') + setup( + name='auto_LiRPA_cuda_utils', + ext_modules=[CUDAExtension('auto_LiRPA_cuda_utils', [ + 'auto_LiRPA/cuda/cuda_utils.cpp', + 'auto_LiRPA/cuda/cuda_kernels.cu' + ])], + cmdclass={'build_ext': BuildExtension.with_options()}, + ) + exit(0) + +if torch.cuda.is_available() and os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False): + try: + import auto_LiRPA_cuda_utils as _cuda_utils + except: + print('CUDA modules have not been installed') + try: + print('Building native CUDA modules...') + code_dir = os.path.dirname(os.path.abspath(__file__)) + verbose = os.environ.get('AUTOLIRPA_DEBUG_CUDA_BUILD', None) is not None + _cuda_utils = load( + 'cuda_utils', [os.path.join(code_dir, 'cuda', 'cuda_utils.cpp'), os.path.join(code_dir, 'cuda', 'cuda_kernels.cu')], verbose=verbose) + print('CUDA modules have been built.') + except: + print('CUDA module build failure. Some features will be unavailable.') + print('Please make sure the latest CUDA toolkit is installed in your system.') + if verbose: + print(sys.exc_info()[2]) + else: + print('Set environment variable AUTOLIRPA_DEBUG_CUDA_BUILD=1 to view build log.') + _cuda_utils = DummyCudaClass() +else: + if os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False): + print('CUDA unavailable. Some features are disabled.') + _cuda_utils = DummyCudaClass() + +double2float = _cuda_utils.double2float + +def test_double2float(): + # Test the double2float function. + import time + shape = (3,4,5) + + a = torch.randn(size=shape, dtype=torch.float64, device='cuda') + a = a.transpose(0,1) + + au = _cuda_utils.double2float(a, "up") + ad = _cuda_utils.double2float(a, "down") + + print(a.size(), au.size(), ad.size()) + + a_flatten = a.reshape(-1) + au_flatten = au.reshape(-1) + ad_flatten = ad.reshape(-1) + + for i in range(a_flatten.numel()): + ai = a_flatten[i].item() + aui = au_flatten[i].item() + adi = ad_flatten[i].item() + print(adi, ai, aui) + assert adi <= ai + assert aui >= ai + del a, au, ad, a_flatten, au_flatten, ad_flatten + + # Performance benchmark. + for j in [1, 4, 16, 64, 256, 1024]: + shape = (j, 512, 1024) + print(f'shape: {shape}') + t = torch.randn(size=shape, dtype=torch.float64, device='cuda') + + torch.cuda.synchronize() + start_time = time.time() + for i in range(10): + tt = t.float() + torch.cuda.synchronize() + del tt + pytorch_time = time.time() - start_time + print(f'pytorch rounding time: {pytorch_time:.4f}') + + torch.cuda.synchronize() + start_time = time.time() + for i in range(10): + tu = _cuda_utils.double2float(t, "up") + torch.cuda.synchronize() + del tu + roundup_time = time.time() - start_time + print(f'cuda round up time: {roundup_time:.4f}') + + torch.cuda.synchronize() + start_time = time.time() + for i in range(10): + td = _cuda_utils.double2float(t, "down") + torch.cuda.synchronize() + del td + rounddown_time = time.time() - start_time + print(f'cuda round down time: {rounddown_time:.4f}') + + del t + + +if __name__ == "__main__": + if len(sys.argv) == 1: + # Some tests. It's not possible to test them automatically because travis does not have CUDA. 
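The `test_double2float` routine above checks the defining invariant of these kernels: `double2float(x, "down") <= x <= double2float(x, "up")` elementwise, which is what makes narrowed float32 bounds sound. For reference, the same directed narrowing can be emulated on CPU with NumPy (a sketch, not the library's kernel):

```python
import numpy as np

def double2float_ref(x64, direction):
    """Narrow float64 -> float32 with directed rounding."""
    x32 = x64.astype(np.float32)            # round-to-nearest by default
    back = x32.astype(np.float64)
    if direction == "down":
        # Step one ULP down wherever nearest rounding went up.
        return np.where(back > x64, np.nextafter(x32, np.float32(-np.inf)), x32)
    return np.where(back < x64, np.nextafter(x32, np.float32(np.inf)), x32)

x = np.random.randn(1000)
assert (double2float_ref(x, "down").astype(np.float64) <= x).all()
assert (double2float_ref(x, "up").astype(np.float64) >= x).all()
```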
+ test_double2float() diff --git a/auto_LiRPA/forward_bound.py b/auto_LiRPA/forward_bound.py new file mode 100644 index 0000000..a090208 --- /dev/null +++ b/auto_LiRPA/forward_bound.py @@ -0,0 +1,306 @@ +from auto_LiRPA.beta_crown import print_optimized_beta +import torch +from torch import Tensor +import warnings +from .bound_ops import * +from .utils import * +from .backward_bound import batched_backward +from .linear_bound import LinearBound +from .perturbations import PerturbationLpNorm + +import sys +sys.setrecursionlimit(1000000) + +def forward_general(self, C=None, node=None, concretize=False, offset=0): + if self.bound_opts['dynamic_forward']: + return self.forward_general_dynamic(C, node, concretize, offset) + + if C is None: + if hasattr(node, 'linear'): + return node.linear.lower, node.linear.upper + if not node.from_input: + node.linear = LinearBound(None, node.value, None, node.value, node.value, node.value) + return node.value, node.value + if not node.perturbed: + node.lower = node.upper = self.get_forward_value(node) + if hasattr(node, 'lower'): + node.linear = LinearBound(None, node.lower, None, node.upper, node.lower, node.upper) + return node.lower, node.upper + + for l_pre in node.inputs: + if not hasattr(l_pre, 'linear'): + self.forward_general(node=l_pre, offset=offset) + inp = [l_pre.linear for l_pre in node.inputs] + node._start = '_forward' + if (C is not None and isinstance(node, BoundLinear) and + not node.is_input_perturbed(1) and not node.is_input_perturbed(2)): + linear = node.bound_forward(self.dim_in, *inp, C=C) + C_merged = True + else: + linear = node.linear = node.bound_forward(self.dim_in, *inp) + C_merged = False + + lw, uw = linear.lw, linear.uw + lower, upper = linear.lb, linear.ub + + if C is not None and not C_merged: + # FIXME use bound_forward of BoundLinear + C_pos, C_neg = C.clamp(min=0), C.clamp(max=0) + _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2)) + _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2)) + lw, uw = _lw, _uw + _lower = torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) + \ + torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2)) + _upper = torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2)) + \ + torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2)) + lower, upper = _lower.squeeze(1), _upper.squeeze(1) + + logger.debug(f'Forward bounds to {node}') + + if concretize: + if lw is not None or uw is not None: + prev_dim_in = 0 + batch_size = lw.shape[0] + assert (lw.ndim > 1) + lA = lw.reshape(batch_size, self.dim_in, -1).transpose(1, 2) + uA = uw.reshape(batch_size, self.dim_in, -1).transpose(1, 2) + for i in range(len(self.root)): + if hasattr(self.root[i], 'perturbation') and self.root[i].perturbation is not None: + _lA = lA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)] + _uA = uA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)] + lower = lower + self.root[i].perturbation.concretize( + self.root[i].center, _lA, sign=-1, aux=self.root[i].aux).view(lower.shape) + upper = upper + self.root[i].perturbation.concretize( + self.root[i].center, _uA, sign=+1, aux=self.root[i].aux).view(upper.shape) + prev_dim_in += self.root[i].dim + linear.lower, linear.upper = lower, upper + + if C is None: + node.linear = linear + node.lower, node.upper = lower, upper + + if self.bound_opts['forward_refinement']: + need_refinement = False + for out in node.output_name: + out_node = self[out] + if getattr(out_node, 'nonlinear', False): + 
need_refinement = True + for i in getattr(out_node, 'requires_input_bounds', []): + if out_node.inputs[i] == node: + need_refinement = True + break + if need_refinement: + forward_refinement(self, node) + return lower, upper + + +def forward_general_dynamic( + self, C=None, node=None, concretize=False, offset=0): + max_dim = self.bound_opts['forward_max_dim'] + + if C is None: + if hasattr(node, 'linear'): + assert not concretize + + linear = node.linear + if offset == 0: + if linear.lw is None: + return linear + elif linear.lw.shape[1] <= max_dim: + return linear + if linear.lw is not None: + lw = linear.lw[:, offset:offset+max_dim] + x_L = linear.x_L[:, offset:offset+max_dim] + x_U = linear.x_U[:, offset:offset+max_dim] + tot_dim = linear.tot_dim + if offset == 0: + lb = linear.lb + else: + lb = torch.zeros_like(linear.lb) + else: + lw = x_L = x_U = None + tot_dim = 0 + lb = linear.lb + return LinearBound( + lw, lb, lw, lb, x_L=x_L, x_U=x_U, + offset=offset, tot_dim=tot_dim, + ) + + # These cases have no coefficient tensor + if not node.from_input: + if concretize: + return node.value, node.value + else: + node.linear = LinearBound( + None, node.value, None, node.value, node.value, node.value) + return node.linear + if not node.perturbed: + if not hasattr(node, 'lower'): + node.lower = node.upper = self.get_forward_value(node) + raise NotImplementedError + if concretize: + return node.lower, node.upper + else: + if offset > 0: + lb = torch.zeros_like(node.lower) + else: + lb = node.lower + node.linear = LinearBound(None, lb, None, lb, node.lower, node.upper) + return node.linear + + if offset == 0: + logger.debug(f'forward_general_dynamic: node={node}') + + inp = [] + for l_pre in node.inputs: + linear_inp = self.forward_general_dynamic(node=l_pre, offset=offset) + linear_inp.lower = getattr(l_pre, 'lower', None) + linear_inp.upper = getattr(l_pre, 'upper', None) + inp.append(linear_inp) + node._start = '_forward' + if (C is not None and isinstance(node, BoundLinear) and + not node.is_input_perturbed(1) and not node.is_input_perturbed(2)): + linear = node.bound_dynamic_forward( + *inp, C=C, max_dim=max_dim, offset=offset) + C_merged = True + else: + linear = node.bound_dynamic_forward( + *inp, max_dim=max_dim, offset=offset) + C_merged = False + if offset > 0: + linear.lb = linear.ub = torch.zeros_like(linear.lb) + + lw, lb, tot_dim = linear.lw, linear.lb, linear.tot_dim + #logger.debug(f'forward_general_dynamic: node={node}, w_size={lw.shape[1]}, tot_dim={tot_dim}') + + if C is not None and not C_merged: + # FIXME use bound_forward of BoundLinear + lw = torch.matmul(lw, C.transpose(-1, -2)) + lb = torch.matmul(lb.unsqueeze(1), C.transpose(-1, -2)).squeeze(1) + + if concretize: + lower = upper = lb + if lw is not None: + batch_size = lw.shape[0] + assert (lw.ndim > 1) + if lw.shape[1] > 0: + A = lw.reshape(batch_size, lw.shape[1], -1).transpose(1, 2) + ptb = PerturbationLpNorm(x_L=linear.x_L, x_U=linear.x_U) + lower = lower + ptb.concretize(x=None, A=A, sign=-1).view(lb.shape) + upper = upper + ptb.concretize(x=None, A=A, sign=1).view(lb.shape) + offset_next = offset + max_dim + more = offset_next < tot_dim + else: + more = False + + if C is None and offset == 0 and not more: + node.linear = linear + + if more: + if lw is not None and lw.shape[1] > 0: + del A + del ptb + del lw + del linear + del inp + # TODO make it non-recursive + lower_next, upper_next = self.forward_general_dynamic( + C, node, concretize=True, offset=offset_next) + lower = lower + lower_next + upper = upper + 
upper_next + + if C is None: + node.lower, node.upper = lower, upper + + return lower, upper + else: + return linear + + +def clean_memory(self, node): + """ Remove linear bounds that are no longer needed. """ + # TODO add an option to retain these bounds + + for inp in node.inputs: + if hasattr(inp, 'linear') and inp.linear is not None: + clean = True + for out in inp.output_name: + out_node = self[out] + if not (hasattr(out_node, 'linear') and out_node.linear is not None): + clean = False + if clean: + if isinstance(inp.linear, tuple): + for item in inp.linear: + del item + delattr(inp, 'linear') + + +def forward_refinement(self, node): + """ Refine forward bounds with backward bound propagation + (only refine unstable positions). """ + unstable_size_before = torch.logical_and(node.lower < 0, node.upper > 0).sum() + if unstable_size_before == 0: + return + unstable_idx, unstable_size = self.get_unstable_locations( + node.lower, node.upper, conv=isinstance(node, BoundConv)) + logger.debug(f'Forward refinement for {node}') + batch_size = node.lower.shape[0] + ret = self.batched_backward( + node, C=None, unstable_idx=unstable_idx, batch_size=batch_size) + self.restore_sparse_bounds( + node, unstable_idx, unstable_size, node.lower, node.upper, + new_lower=ret[0], new_upper=ret[1]) + unstable_size_after = torch.logical_and(node.lower < 0, node.upper > 0).sum() + logger.debug(f' Unstable neurons: {unstable_size_before} -> {unstable_size_after}') + # TODO also update linear bounds? + + +def init_forward(self, root, dim_in): + if dim_in == 0: + raise ValueError("At least one node should have a specified perturbation") + prev_dim_in = 0 + # Assumption: root[0] is the input node which implies batch_size + batch_size = root[0].value.shape[0] + dynamic = self.bound_opts['dynamic_forward'] + for i in range(len(root)): + if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: + shape = root[i].linear.lw.shape + if dynamic: + if shape[1] != dim_in: + raise NotImplementedError('Dynamic forward bound is not supported yet when there are multiple perturbed inputs.') + ptb = root[i].perturbation + if (type(ptb) != PerturbationLpNorm or ptb.norm < np.inf + or ptb.x_L is None or ptb.x_U is None): + raise NotImplementedError( + 'For dynamic forward bounds, only Linf (box) perturbations are supported, and x_L and x_U must be explicitly provided.') + root[i].linear.x_L = ( + ptb.x_L_sparse.view(batch_size, -1) if ptb.sparse + else ptb.x_L.view(batch_size, -1)) + root[i].linear.x_U = ( + ptb.x_U_sparse.view(batch_size, -1) if ptb.sparse + else ptb.x_U.view(batch_size, -1)) + else: + lw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.lw) + lw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.lw + if root[i].linear.lw.data_ptr() == root[i].linear.uw.data_ptr(): + uw = lw + else: + uw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.uw) + uw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.uw + root[i].linear.lw = lw + root[i].linear.uw = uw + if i >= self.num_global_inputs: + root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat( + *([batch_size] + [1] * self.forward_value.ndim)) + prev_dim_in += shape[1] + else: + b = fv = root[i].forward_value + shape = fv.shape + if root[i].from_input: + w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device) + warnings.warn(f'Creating a LinearBound with zero weights with shape {w.shape}') + else: + w = None + root[i].linear = LinearBound(w, b, w, b, b, b) + root[i].lower = root[i].upper = b + 
root[i].interval = (root[i].lower, root[i].upper) diff --git a/auto_LiRPA/interval_bound.py b/auto_LiRPA/interval_bound.py new file mode 100644 index 0000000..0fd09c4 --- /dev/null +++ b/auto_LiRPA/interval_bound.py @@ -0,0 +1,147 @@ +import torch +from .bound_ops import * + + +def IBP_general(self, node=None, C=None, delete_bounds_after_use=False): + + def _delete_unused_bounds(node_list): + """Delete bounds from input layers after use to save memory. Used when + sparse_intermediate_bounds_with_ibp is true.""" + if delete_bounds_after_use: + for n in node_list: + del n.interval + del n.lower + del n.upper + + if self.bound_opts.get('loss_fusion', False): + res = self._IBP_loss_fusion(node, C) + if res is not None: + return res + + if not node.perturbed and hasattr(node, 'forward_value'): + node.lower, node.upper = node.interval = ( + node.forward_value, node.forward_value) + + to_be_deleted_bounds = [] + if not hasattr(node, 'interval'): + for n in node.inputs: + if not hasattr(n, 'interval'): + # Node n does not have interval bounds; we must compute it. + self.IBP_general( + n, delete_bounds_after_use=delete_bounds_after_use) + to_be_deleted_bounds.append(n) + inp = [n_pre.interval for n_pre in node.inputs] + if (C is not None and isinstance(node, BoundLinear) + and not node.is_input_perturbed(1)): + # merge the last BoundLinear node with the specification, available + # when weights of this layer are not perturbed + ret = node.interval_propagate(*inp, C=C) + _delete_unused_bounds(to_be_deleted_bounds) + return ret + else: + node.interval = node.interval_propagate(*inp) + + node.lower, node.upper = node.interval + if isinstance(node.lower, torch.Size): + node.lower = torch.tensor(node.lower) + node.interval = (node.lower, node.upper) + if isinstance(node.upper, torch.Size): + node.upper = torch.tensor(node.upper) + node.interval = (node.lower, node.upper) + + if C is not None: + _delete_unused_bounds(to_be_deleted_bounds) + return BoundLinear.interval_propagate(None, node.interval, C=C) + else: + _delete_unused_bounds(to_be_deleted_bounds) + return node.interval + +def _IBP_loss_fusion(self, node, C): + """Merge BoundLinear, BoundGatherElements and BoundSub. + + Improvement when loss fusion is used in training. + """ + + # not using loss fusion + if not self.bound_opts.get('loss_fusion', False): + return None + + # Currently this function has issues in more complicated networks. 
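The merge performed below rests on a small piece of algebra: a final linear layer `z = Wx + b` followed by the margin `z - z[y]` is itself a per-sample linear layer with rows `W_i - W_y` and biases `b_i - b_y`, so IBP only needs one `interval_propagate` through the merged layer. A standalone check of that identity, in plain PyTorch with no auto_LiRPA internals:

```python
import torch

def fused_margin_layer(W, b, y):
    # Rows W_i - W_y and biases b_i - b_y, one set per sample in the batch.
    W_m = W.unsqueeze(0) - W[y].unsqueeze(1)   # (batch, out, in)
    b_m = b.unsqueeze(0) - b[y].unsqueeze(1)   # (batch, out)
    return W_m, b_m

W, b = torch.randn(10, 4), torch.randn(10)
x, y = torch.randn(3, 4), torch.tensor([2, 0, 7])
W_m, b_m = fused_margin_layer(W, b, y)
ref = x @ W.T + b
ref = ref - ref.gather(1, y.unsqueeze(-1))        # z - z[y]
out = torch.einsum('bi,boi->bo', x, W_m) + b_m    # merged layer
assert torch.allclose(out, ref, atol=1e-6)
```

The code below implements the same idea with `torch.gather` so that it stays shape-generic for batched labels.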
+ if self.bound_opts.get('no_ibp_loss_fusion', False): + return None + + if (C is None and isinstance(node, BoundSub) + and isinstance(node.inputs[1], BoundGatherElements) + and isinstance(node.inputs[0], BoundLinear)): + node_gather = node.inputs[1] + node_linear = node.inputs[0] + node_start = node_linear.inputs[0] + w = node_linear.inputs[1].param + b = node_linear.inputs[2].param + labels = node_gather.inputs[1] + if not hasattr(node_start, 'interval'): + self.IBP_general(node_start) + for n in node_gather.inputs: + if not hasattr(n, 'interval'): + self.IBP_general(n) + if torch.isclose(labels.lower, labels.upper, 1e-8).all(): + labels = labels.lower + batch_size = labels.shape[0] + w = w.expand(batch_size, *w.shape) + w = w - torch.gather( + w, dim=1, + index=labels.unsqueeze(-1).repeat(1, w.shape[1], w.shape[2])) + b = b.expand(batch_size, *b.shape) + b = b - torch.gather(b, dim=1, + index=labels.repeat(1, b.shape[1])) + lower, upper = node_start.interval + lower, upper = lower.unsqueeze(1), upper.unsqueeze(1) + node.lower, node.upper = node_linear.interval_propagate( + (lower, upper), (w, w), (b.unsqueeze(1), b.unsqueeze(1))) + node.interval = node.lower, node.upper = ( + node.lower.squeeze(1), node.upper.squeeze(1)) + return node.interval + + return None + + +def check_IBP_intermediate(self, node): + """ Check if we use IBP bounds to compute intermediate bounds on this node. + Basically we check if we can get bounds by only visiting operators in + `self.ibp_intermediate`. + + Currently, assume all eligible operators have exactly one input. """ + nodes = [] + while not hasattr(node, 'lower') or not hasattr(node, 'upper'): + if type(node) not in self.ibp_intermediate: + return False + nodes.append(node) + node = node.inputs[0] + nodes.reverse() + for n in nodes: + node.interval = self.IBP_general(n) + return True + + +def check_IBP_first_linear(self, node): + """Here we avoid creating a big C matrix in the first linear layer. + Disable this optimization when we have beta for intermediate layer bounds. + Disable this optimization when we need the A matrix of the first nonlinear + layer, forcibly use CROWN to record A matrix. + """ + # This is the list of all intermediate layers where we need to refine. + if self.intermediate_constr is not None: + intermediate_beta_enabled_layers = [ + k for v in self.intermediate_constr.values() for k in v] + else: + intermediate_beta_enabled_layers = [] + + if (node.name not in self.needed_A_dict.keys() + and (type(node) == BoundLinear + or type(node) == BoundConv + and node.name not in intermediate_beta_enabled_layers)): + if type(node.inputs[0]) == BoundInput: + node.lower, node.upper = self.IBP_general(node) + return True + + return False diff --git a/auto_LiRPA/linear_bound.py b/auto_LiRPA/linear_bound.py new file mode 100644 index 0000000..ef50028 --- /dev/null +++ b/auto_LiRPA/linear_bound.py @@ -0,0 +1,33 @@ +class LinearBound: + def __init__( + self, lw=None, lb=None, uw=None, ub=None, lower=None, upper=None, + from_input=None, x_L=None, x_U=None, offset=0, tot_dim=None): + self.lw = lw + self.lb = lb + self.uw = uw + self.ub = ub + self.lower = lower + self.upper = upper + self.from_input = from_input + self.x_L = x_L + self.x_U = x_U + # Offset for input variables. Used for batched forward bound + # propagation. 
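To make the fields above concrete: a `LinearBound` stores linear functions of the input, `lw·x + lb <= f(x) <= uw·x + ub`, which are concretized into scalar bounds over a box by standard interval arithmetic. A minimal sketch, assuming a Linf box `[x_L, x_U]` and coefficients shaped `(batch, dim_in, dim_out)`:

```python
import torch

def concretize_box(lw, lb, uw, ub, x_L, x_U):
    # Lower bound of lb + w.x over the box is lb + w.center - |w|.radius.
    center = ((x_U + x_L) / 2).unsqueeze(-1)
    radius = ((x_U - x_L) / 2).unsqueeze(-1)
    lower = lb + (lw * center).sum(1) - (lw.abs() * radius).sum(1)
    upper = ub + (uw * center).sum(1) + (uw.abs() * radius).sum(1)
    return lower, upper

lw = uw = torch.randn(2, 3, 5)
lb = ub = torch.randn(2, 5)
lo, hi = concretize_box(lw, lb, uw, ub, -torch.ones(2, 3), torch.ones(2, 3))
assert (lo <= hi).all()
```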
+ self.offset = offset + if tot_dim is not None: + self.tot_dim = tot_dim + elif lw is not None: + self.tot_dim = lw.shape[1] + else: + self.tot_dim = 0 + + def is_single_bound(self): + """Check whether the linear lower bound and the linear upper bound are + the same.""" + if (self.lw is not None and self.uw is not None + and self.lb is not None and self.ub is not None): + return (self.lw.data_ptr() == self.uw.data_ptr() + and self.lb.data_ptr() == self.ub.data_ptr() + and self.x_L is not None and self.x_U is not None) + else: + return True diff --git a/auto_LiRPA/operators/__init__.py b/auto_LiRPA/operators/__init__.py index 9020901..56cb4f3 100644 --- a/auto_LiRPA/operators/__init__.py +++ b/auto_LiRPA/operators/__init__.py @@ -1,7 +1,9 @@ from .base import * from .linear import * from .convolution import * +from .pooling import * from .activation import * +from .nonlinear import * from .bivariate import * from .normalization import * from .shape import * @@ -11,5 +13,7 @@ from .constant import * from .leaf import * from .logical import * -from .dropout import * -from .dtype import * \ No newline at end of file +from .dropout import * +from .dtype import * +from .cut_ops import * +from .solver_utils import grb \ No newline at end of file diff --git a/auto_LiRPA/operators/activation.py b/auto_LiRPA/operators/activation.py index 7aa2dca..1383100 100644 --- a/auto_LiRPA/operators/activation.py +++ b/auto_LiRPA/operators/activation.py @@ -1,43 +1,74 @@ """ Activation operators or other unary nonlinear operators""" +from typing import Optional, Tuple +import torch +from torch import Tensor from .base import * +from .clampmult import multiply_by_A_signs +from .solver_utils import grb +from ..utils import unravel_index, logger, prod + + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + class BoundActivation(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.requires_input_bounds = [0] self.relaxed = False def _init_masks(self, x): - self.mask_pos = torch.ge(x.lower, 0).to(torch.float) - self.mask_neg = torch.le(x.upper, 0).to(torch.float) - self.mask_both = 1 - self.mask_pos - self.mask_neg + self.mask_pos = x.lower >= 0 + self.mask_neg = x.upper <= 0 + self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg)) - def _init_linear(self, x, dim_opt=None): + def init_linear_relaxation(self, x, dim_opt=None): self._init_masks(x) self.lw = torch.zeros_like(x.lower) self.lb = self.lw.clone() self.uw = self.lw.clone() self.ub = self.lw.clone() - def _add_linear(self, mask, type, k, x0, y0): - if mask is None: - mask = 1 + def add_linear_relaxation(self, mask, type, k, x0, y0): if type == 'lower': w_out, b_out = self.lw, self.lb else: w_out, b_out = self.uw, self.ub - w_out += mask * k - b_out += mask * (-x0 * k + y0) + + if mask is None: + if isinstance(k, Tensor) and k.ndim > 0: + w_out[:] = k + else: + w_out.fill_(k) + else: + if isinstance(k, Tensor): + w_out[..., mask] = k[..., mask].to(w_out) + else: + w_out[..., mask] = k + + if (not isinstance(x0, Tensor) and x0 == 0 + and not isinstance(y0, Tensor) and y0 == 0): + pass + else: + b = -x0 * k + y0 + if mask is None: + if b.ndim > 0: + b_out[:] = b + else: + b_out.fill_(b) + else: + b_out[..., mask] = 
b[..., mask] def bound_relax(self, x): return not_implemented_op(self, 'bound_relax') - + def interval_propagate(self, *v): - return self.default_interval_propagate(*v) + return self.default_interval_propagate(*v) def bound_backward(self, last_lA, last_uA, x): if not self.relaxed: - self._init_linear(x) + self.init_linear_relaxation(x) self.bound_relax(x) def _bound_oneside(last_A, sign=-1): @@ -45,23 +76,25 @@ def _bound_oneside(last_A, sign=-1): return None, 0 if sign == -1: w_pos, b_pos, w_neg, b_neg = ( - self.lw.unsqueeze(0), self.lb.unsqueeze(0), + self.lw.unsqueeze(0), self.lb.unsqueeze(0), self.uw.unsqueeze(0), self.ub.unsqueeze(0)) else: w_pos, b_pos, w_neg, b_neg = ( - self.uw.unsqueeze(0), self.ub.unsqueeze(0), + self.uw.unsqueeze(0), self.ub.unsqueeze(0), self.lw.unsqueeze(0), self.lb.unsqueeze(0)) + w_pos = maybe_unfold_patches(w_pos, last_A) + w_neg = maybe_unfold_patches(w_neg, last_A) + b_pos = maybe_unfold_patches(b_pos, last_A) + b_neg = maybe_unfold_patches(b_neg, last_A) if self.batch_dim == 0: - _A = last_A.clamp(min=0) * w_pos + last_A.clamp(max=0) * w_neg - _bias = last_A.clamp(min=0) * b_pos + last_A.clamp(max=0) * b_neg - if _bias.ndim > 2: - _bias = torch.sum(_bias, dim=list(range(2, _bias.ndim))) + _A, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg) elif self.batch_dim == -1: + # FIXME: why this is different from above? mask = torch.gt(last_A, 0.).to(torch.float) _A = last_A * (mask * w_pos.unsqueeze(1) + - (1 - mask) * w_neg.unsqueeze(1)) + (1 - mask) * w_neg.unsqueeze(1)) _bias = last_A * (mask * b_pos.unsqueeze(1) + - (1 - mask) * b_neg.unsqueeze(1)) + (1 - mask) * b_neg.unsqueeze(1)) if _bias.ndim > 2: _bias = torch.sum(_bias, dim=list(range(2, _bias.ndim))) else: @@ -74,29 +107,39 @@ def _bound_oneside(last_A, sign=-1): return [(lA, uA)], lbias, ubias + @staticmethod + @torch.jit.script + def bound_forward_w( + relax_lw: Tensor, relax_uw: Tensor, x_lw: Tensor, x_uw: Tensor, dim: int): + lw = (relax_lw.unsqueeze(dim).clamp(min=0) * x_lw + + relax_lw.unsqueeze(dim).clamp(max=0) * x_uw) + uw = (relax_uw.unsqueeze(dim).clamp(max=0) * x_lw + + relax_uw.unsqueeze(dim).clamp(min=0) * x_uw) + return lw, uw + + @staticmethod + @torch.jit.script + def bound_forward_b( + relax_lw: Tensor, relax_uw: Tensor, relax_lb: Tensor, + relax_ub: Tensor, x_lb: Tensor, x_ub: Tensor): + lb = relax_lw.clamp(min=0) * x_lb + relax_lw.clamp(max=0) * x_ub + relax_lb + ub = relax_uw.clamp(max=0) * x_lb + relax_uw.clamp(min=0) * x_ub + relax_ub + return lb, ub + def bound_forward(self, dim_in, x): if not self.relaxed: - self._init_linear(x) + self.init_linear_relaxation(x) self.bound_relax(x) - if self.lw.ndim > 0: - if x.lw is not None: - lw = self.lw.unsqueeze(1).clamp(min=0) * x.lw + \ - self.lw.unsqueeze(1).clamp(max=0) * x.uw - uw = self.uw.unsqueeze(1).clamp(max=0) * x.lw + \ - self.uw.unsqueeze(1).clamp(min=0) * x.uw - else: - lw = uw = None + assert (x.lw is None) == (x.uw is None) + + dim = 1 if self.lw.ndim > 0 else 0 + + if x.lw is not None: + lw, uw = BoundActivation.bound_forward_w(self.lw, self.uw, x.lw, x.uw, dim) else: - if x.lw is not None: - lw = self.lw.unsqueeze(0).clamp(min=0) * x.lw + \ - self.lw.unsqueeze(0).clamp(max=0) * x.uw - uw = self.uw.unsqueeze(0).clamp(min=0) * x.lw + \ - self.uw.unsqueeze(0).clamp(max=0) * x.uw - else: - lw = uw = None - lb = self.lw.clamp(min=0) * x.lb + self.lw.clamp(max=0) * x.ub + self.lb - ub = self.uw.clamp(max=0) * x.lb + self.uw.clamp(min=0) * x.ub + self.ub + lw = uw = None + lb, ub = 
BoundActivation.bound_forward_b(self.lw, self.uw, self.lb, self.ub, x.lb, x.ub) return LinearBound(lw, lb, uw, ub) @@ -106,47 +149,96 @@ def interval_propagate(self, *v): class BoundOptimizableActivation(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - # Two stages: `init` (initializing parameters) and `opt` (optimizing parameters). + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + # Stages: + # * `init`: initializing parameters + # * `opt`: optimizing parameters + # * `reuse`: not optimizing parameters but reuse saved values # If `None`, it means activation optimization is currently not used. self.opt_stage = None + self.alpha = OrderedDict() + # Save patch sizes during bound_backward() for each output_node. + self.patch_size = {} + # Location of batch dimension in self.alpha. Must be set by children. + self.alpha_batch_dim = None + # A torch.bool mask of shape Tensor([batch_size]) that conditions the sample of alpha and beta to update + # If set to None, update all samples + # If not None, select those corresponding to 1 to update + self.alpha_beta_update_mask = None - """ Enter the stage for initializing bound optimization. Optimized bounds are not used in - this stage. """ def opt_init(self): + """Enter the stage for initializing bound optimization. Optimized bounds + are not used in this stage.""" self.opt_stage = 'init' - """ Start optimizing bounds """ def opt_start(self): + """Start optimizing bounds.""" self.opt_stage = 'opt' - """ start_nodes: a list of starting nodes [(node, size)] during - CROWN backward bound propagation""" + def opt_reuse(self): + """ Reuse optimizing bounds """ + self.opt_stage = 'reuse' + + def opt_no_reuse(self): + """ Finish reusing optimized bounds """ + if self.opt_stage == 'reuse': + self.opt_stage = None + + def opt_end(self): + """ End optimizing bounds """ + self.opt_stage = None + def init_opt_parameters(self, start_nodes): + """ start_nodes: a list of starting nodes [(node, size)] during + CROWN backward bound propagation""" raise NotImplementedError - def _init_linear(self, x, dim_opt=None): + def clip_alpha_(self): + pass + + def init_linear_relaxation(self, x, dim_opt=None): self._init_masks(x) # The first dimension of size 2 is used for lA and uA respectively, # when computing intermediate bounds. - if self.opt_stage == 'opt' and dim_opt: - self.lw = torch.zeros(2, dim_opt, *x.lower.shape).to(x.lower) + if self.opt_stage in ['opt', 'reuse'] and dim_opt is not None: + # For optimized bounds, we have independent lw for each output dimension for bound optimization. + # If the output layer is a fully connected layer, len(dim_opt) = 1. + # If the output layer is a conv layer, len(dim_opt) = 3 but we only use the out_c dimension to create slopes/bias. + # Variables are shared among out_h, out_w dimensions so far. + dim = dim_opt if isinstance(dim_opt, int) else dim_opt[0] + self.lw = torch.zeros(2, dim, *x.lower.shape).to(x.lower) else: + # Without optimized bounds, the lw, lb (slope, biase) etc only depend on intermediate layer bounds, + # and are shared among different output dimensions. 
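Both branches below feed `_bound_oneside`, whose core operation (via `multiply_by_A_signs`) selects a relaxation per sign of the incoming `A` matrix. For dense tensors the computation reduces to the clamp-and-multiply sketched here, mirroring the explicit code this diff replaces; the real helper additionally handles `Patches` and is JIT-compiled:

```python
import torch

def multiply_by_A_signs_dense(last_A, w_pos, w_neg, b_pos, b_neg):
    # Positive entries of last_A take the (w_pos, b_pos) relaxation,
    # negative entries take (w_neg, b_neg).
    A_pos, A_neg = last_A.clamp(min=0), last_A.clamp(max=0)
    A_next = A_pos * w_pos + A_neg * w_neg
    bias = A_pos * b_pos + A_neg * b_neg
    if bias.ndim > 2:
        bias = bias.sum(dim=list(range(2, bias.ndim)))
    return A_next, bias
```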
self.lw = torch.zeros_like(x.lower) self.lb = self.lw.clone() self.uw = self.lw.clone() - self.ub = self.lw.clone() + self.ub = self.lw.clone() def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): - self._start = start_node - - if self.opt_stage != 'opt': - return super().bound_backward(last_lA, last_uA, x) - assert self.batch_dim == 0 + self._start = start_node.name + + if self.opt_stage not in ['opt', 'reuse']: + last_A = last_lA if last_lA is not None else last_uA + # Returned [(lA, uA)], lbias, ubias + As, lbias, ubias = super().bound_backward(last_lA, last_uA, x) + if isinstance(last_A, Patches): + A_prod = As[0][1].patches if As[0][0] is None else As[0][1].patches + # FIXME: Unify this function with BoundReLU + # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. + if start_node is not None: + if last_A.unstable_idx is not None: + # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). + self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] + else: + # Regular patches. + self.patch_size[start_node.name] = A_prod.size() + return As, lbias, ubias + assert self.batch_dim == 0 if not self.relaxed: - self._init_linear(x, dim_opt=start_shape) + self.init_linear_relaxation(x, dim_opt=start_shape) self.bound_relax(x) def _bound_oneside(last_A, sign=-1): @@ -156,11 +248,12 @@ def _bound_oneside(last_A, sign=-1): w_pos, b_pos, w_neg, b_neg = self.lw[0], self.lb[0], self.uw[0], self.ub[0] else: w_pos, b_pos, w_neg, b_neg = self.uw[1], self.ub[1], self.lw[1], self.lb[1] - _A = last_A.clamp(min=0) * w_pos + last_A.clamp(max=0) * w_neg - _bias = last_A.clamp(min=0) * b_pos + last_A.clamp(max=0) * b_neg - if _bias.ndim > 2: - _bias = torch.sum(_bias, list(range(2, _bias.ndim))) - return _A, _bias + w_pos = maybe_unfold_patches(w_pos, last_A) + w_neg = maybe_unfold_patches(w_neg, last_A) + b_pos = maybe_unfold_patches(b_pos, last_A) + b_neg = maybe_unfold_patches(b_neg, last_A) + A_prod, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg) + return A_prod, _bias lA, lbias = _bound_oneside(last_lA, sign=-1) uA, ubias = _bound_oneside(last_uA, sign=+1) @@ -169,692 +262,722 @@ def _bound_oneside(last_A, sign=-1): def _no_bound_parameters(self): raise AttributeError('Bound parameters have not been initialized.' 
- 'Please call `compute_bounds` with `method=CROWN-optimized`' - ' at least once.') + 'Please call `compute_bounds` with `method=CROWN-optimized`' + ' at least once.') -class BoundLeakyRelu(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True - self.options = options.get('relu') - self.alpha = attr['alpha'] + def dump_optimized_params(self): + raise NotImplementedError - @Bound.save_io_shape - def forward(self, x): - return F.leaky_relu(x, negative_slope=self.alpha) + def restore_optimized_params(self): + raise NotImplementedError - def bound_backward(self, last_lA, last_uA, x=None, start_node=None, start_shape=None): - if x is not None: - lb_r = x.lower.clamp(max=0) - ub_r = x.upper.clamp(min=0) - else: - lb_r = self.lower.clamp(max=0) - ub_r = self.upper.clamp(min=0) - ub_r = torch.max(ub_r, lb_r + 1e-8) - upper_d = (ub_r - self.alpha * lb_r) / (ub_r - lb_r) - upper_b = - lb_r * upper_d + self.alpha * lb_r + def set_alpha_beta_update_mask(self, mask): + self.alpha_beta_update_mask = mask - if self.options == "same-slope": - # the same slope for upper and lower - lower_d = upper_d - elif self.options == "zero-lb": - # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN - lower_d = (upper_d >= 1.0).float() + (upper_d < 1.0).float() * self.alpha - elif self.options == "one-lb": - # Always use slope 1 as lower bound - lower_d = (upper_d > 0.0).float() + (upper_d <= 0.0).float() * self.alpha - else: - lower_d = (upper_d > 0.5).float() + (upper_d <= 0.5).float() * self.alpha - - upper_d = upper_d.unsqueeze(0) - lower_d = lower_d.unsqueeze(0) - # Choose upper or lower bounds based on the sign of last_A - uA = lA = None - ubias = lbias = 0 - if last_uA is not None: - neg_uA = last_uA.clamp(max=0) - pos_uA = last_uA.clamp(min=0) - uA = upper_d * pos_uA + lower_d * neg_uA - ubias = self.get_bias(pos_uA, upper_b) - if last_lA is not None: - neg_lA = last_lA.clamp(max=0) - pos_lA = last_lA.clamp(min=0) - lA = upper_d * neg_lA + lower_d * pos_lA - lbias = self.get_bias(neg_lA, upper_b) - return [(lA, uA)], lbias, ubias + def clean_alpha_beta_update_mask(self): + self.alpha_beta_update_mask = None class BoundRelu(BoundOptimizableActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.options = options - self.relu_options = options.get('relu', 'adaptive') + self.relu_options = options.get('relu', 'adaptive') # FIXME: use better names. + self.use_sparse_spec_alpha = options.get('sparse_spec_alpha', False) + self.use_sparse_features_alpha = options.get('sparse_features_alpha', False) self.beta = self.beta_mask = self.masked_beta = self.sparse_beta = None self.split_beta_used = False self.history_beta_used = False self.flattened_nodes = None # Save patches size for each output node. self.patch_size = {} + self.cut_used = False + self.cut_module = None + # Alpha dimension is (2, output_shape, batch, *shape) for ReLU. 
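A small shape illustration of the layout named above (sizes are arbitrary): index 0 along the first dimension holds the slopes used when back-propagating lower bounds, index 1 those for upper bounds, matching how `_bound_oneside` above indexes `self.lw[0]` versus `self.uw[1]`:

```python
import torch

spec_size, batch, relu_shape = 10, 8, (16,)
alpha = torch.rand(2, spec_size, batch, *relu_shape)
slopes_for_lb = alpha[0]   # (spec, batch, *relu_shape)
slopes_for_ub = alpha[1]
```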
+ self.alpha_batch_dim = 2 def init_opt_parameters(self, start_nodes): - self.alpha = OrderedDict() ref = self.inputs[0].lower # a reference variable for getting the shape - for ns, size_s in start_nodes: - self.alpha[ns] = torch.empty([2, size_s, ref.size(0), *self.shape], - dtype=torch.float, device=ref.device, requires_grad=True) - for k, v in self.alpha.items(): - v.data.copy_(self.lower_d.data) # Initial from adaptive lower bounds. + batch_size = ref.size(0) + self.alpha = OrderedDict() + self.alpha_lookup_idx = OrderedDict() # For alpha with sparse spec dimention. + self.alpha_indices = None # indices of non-zero alphas. + verbosity = self.options.get('verbosity', 0) + + # Alpha can be sparse in both spec dimension, and the C*H*W dimension. + # We first deal with the sparse-feature alpha, which is sparse in the + # C*H*W dimesnion of this layer. + minimum_sparsity = self.options.get('minimum_sparsity', 0.9) + if (hasattr(self.inputs[0], 'lower') and hasattr(self.inputs[0], 'upper') + and self.use_sparse_features_alpha): + # Pre-activation bounds available, we will store the alpha for unstable neurons only. + # Since each element in a batch can have different unstable neurons, + # for simplicity we find a super-set using any(dim=0). + # This can be non-ideal if the x in a batch are very different. + self.alpha_indices = torch.logical_and( + self.inputs[0].lower < 0, self.inputs[0].upper > 0).any(dim=0).nonzero(as_tuple=True) + total_neuron_size = self.inputs[0].lower.numel() // batch_size + if self.alpha_indices[0].size(0) <= minimum_sparsity * total_neuron_size: + # Shape is the number of unstable neurons in this layer. + alpha_shape = [self.alpha_indices[0].size(0)] + # Skip the batch, spec dimension, and find the lower slopes for all unstable neurons. + if len(self.alpha_indices) == 1: + # This layer is after a linear layer. + alpha_init = self.lower_d[:, :, self.alpha_indices[0]] + elif len(self.alpha_indices) == 3: + # This layer is after a conv layer. + alpha_init = self.lower_d[ + :, :, self.alpha_indices[0], self.alpha_indices[1], + self.alpha_indices[2]] + else: + raise ValueError + if verbosity > 0: + print(f'layer {self.name} using sparse-features alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})') + else: + alpha_shape = self.shape # Full alpha. + alpha_init = self.lower_d + if verbosity > 0: + print(f'layer {self.name} using full alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})') + self.alpha_indices = None # Use full alpha. + else: + alpha_shape = self.shape # Full alpha. + alpha_init = self.lower_d + # Now we start to create alphas for all start nodes. + # When sparse-spec feature is enabled, alpha is created for only + # unstable neurons in start node. + for ns, output_shape, unstable_idx in start_nodes: + if isinstance(output_shape, (list, tuple)): + if len(output_shape) > 1: + size_s = prod(output_shape) # Conv layers. + else: + size_s = output_shape[0] + else: + size_s = output_shape + # unstable_idx may be a tensor (dense layer or conv layer + # with shared alpha), or tuple of 3-d tensors (conv layer with + # non-sharing alpha). 
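The sparse-spec path below keeps alphas only for unstable start-node neurons and resolves them through a lookup table, with slot 0 reserved as a fallback for neurons missing from the table. A toy version for a dense (fully connected) start node, illustrative sizes only:

```python
import torch

size_s = 10                                  # start node output size
unstable_idx = torch.tensor([1, 4, 7])       # unstable output neurons
sparsity = unstable_idx.numel()

lookup = torch.zeros(size_s, dtype=torch.long)
lookup[unstable_idx] = torch.arange(1, sparsity + 1)  # slot 0 is reserved

alpha = torch.empty(2, sparsity + 1, 8, 16)  # (2, sparse_spec, batch, shape)
rows = lookup[torch.tensor([4, 7])]          # spec indices -> alpha rows
selected = alpha[:, rows]                    # alphas for those two specs
```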
+ sparsity = float('inf') if unstable_idx is None else unstable_idx.size(0) if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].size(0) + if sparsity <= minimum_sparsity * size_s and self.use_sparse_spec_alpha: + if verbosity > 0: + print(f'layer {self.name} start_node {ns} using sparse-spec alpha with unstable size {sparsity} total_size {size_s} output_shape {output_shape}') + # For fully connected layer, or conv layer with shared alpha per channel. + # shape is (2, sparse_spec, batch, this_layer_shape) + # We create sparse specification dimension, where the spec dimension of alpha only includes slopes for unstable neurons in start_node. + self.alpha[ns] = torch.empty([2, sparsity + 1, batch_size, *alpha_shape], + dtype=torch.float, device=ref.device, requires_grad=True) + self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, sparse_spec) dimensions. + # unstable_idx is a list of used neurons (or channels for BoundConv) for the start_node. + assert unstable_idx.ndim == 1 if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].ndim == 1 + # We only need to the alpha for the unstable neurons in start_node. + indices = torch.arange(1, sparsity + 1, device=alpha_init.device, dtype=torch.long) + if isinstance(output_shape, int) or len(output_shape) == 1: + # Fully connected layers, or conv layer in patches mode with partially shared alpha (pixels in the same channel use the same alpha). + self.alpha_lookup_idx[ns] = torch.zeros(size_s, dtype=torch.long, device=alpha_init.device) + # This lookup table maps the unstable_idx to the actual alpha location in self.alpha[ns]. + # Note that self.alpha[ns][:,0] is reserved for any unstable neurons that are not found in the lookup table. This usually should not + # happen, unless reference bounds are not properly set. + self.alpha_lookup_idx[ns].data[unstable_idx] = indices + else: + # conv layer in matrix mode, or in patches mode but with non-shared alpha. The lookup table is 3-d. + assert len(output_shape) == 3 + self.alpha_lookup_idx[ns] = torch.zeros(output_shape, dtype=torch.long, device=alpha_init.device) + if isinstance(unstable_idx, torch.Tensor): + # Convert the unstable index from flattend 1-d to 3-d. (matrix mode). + unstable_idx_3d = unravel_index(unstable_idx, output_shape) + else: + # Patches mode with non-shared alpha, unstable_idx is already 3d. + unstable_idx_3d = unstable_idx + # Build look-up table. + self.alpha_lookup_idx[ns].data[unstable_idx_3d[0], unstable_idx_3d[1], unstable_idx_3d[2]] = indices + else: + if verbosity > 0: + print(f'layer {self.name} start_node {ns} using full alpha with unstable size {sparsity if unstable_idx is not None else None} total_size {size_s} output_shape {output_shape}') + # alpha shape is (2, spec, batch, this_layer_shape). "this_layer_shape" may still be sparse. + self.alpha[ns] = torch.empty([2, size_s, batch_size, *alpha_shape], + dtype=torch.float, device=ref.device, requires_grad=True) + self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, spec) dimensions + # alpha_lookup_idx can be used for checking if sparse alpha is used or not. + self.alpha_lookup_idx[ns] = None + + def clip_alpha_(self): + for v in self.alpha.values(): + v.data = torch.clamp(v.data, 0., 1.) 
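`clip_alpha_` above clamps every slope into `[0, 1]`, which is exactly the range of sound lower-bound slopes for an unstable ReLU: any line `a*x` through the origin with `a` in `[0, 1]` stays below `relu(x)` on `[l, u]` when `l < 0 < u`. A quick numeric check, together with the chord used as the CROWN upper bound (computed by `_relu_upper_bound` below):

```python
import torch

l, u = -1.5, 2.0
x = torch.linspace(l, u, 1001)
for a in (0.0, 0.3, 1.0):
    assert (a * x <= torch.relu(x) + 1e-6).all()
chord = u / (u - l) * (x - l)
assert (chord >= torch.relu(x) - 1e-6).all()
```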
- @Bound.save_io_shape def forward(self, x): self.shape = x.shape[1:] if self.flattened_nodes is None: self.flattened_nodes = x[0].reshape(-1).shape[0] return F.relu(x) - # Linear relaxation for nonlinear functions - # Used for forward mode bound propagation - def bound_relax(self, x): - # FIXME maybe avoid using `mask` which looks inefficient - # m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) - self._add_linear(mask=self.mask_neg, type='lower', - k=torch.zeros_like(x.lower), x0=0, y0=0) - self._add_linear(mask=self.mask_neg, type='upper', - k=torch.zeros_like(x.lower), x0=0, y0=0) - self._add_linear(mask=self.mask_pos, type='lower', - k=torch.ones_like(x.lower), x0=0, y0=0) - self._add_linear(mask=self.mask_pos, type='upper', - k=torch.ones_like(x.lower), x0=0, y0=0) - upper = torch.max(x.upper, x.lower + 1e-8) - delta = 1e-8 - r = (x.upper - x.lower).clamp(min=delta) - upper_k = x.upper / r + delta / r - self._add_linear(mask=self.mask_both, type='upper', - k=upper_k, x0=x.lower, y0=0) - if self.relu_options == "same-slope": + def _forward_relaxation(self, x): + self._init_masks(x) + self.mask_pos = self.mask_pos.to(x.lower) + self.mask_both = self.mask_both.to(x.lower) + + upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper) + self.uw = self.mask_pos + self.mask_both * upper_k + self.ub = self.mask_both * upper_b + + if self.opt_stage in ['opt', 'reuse']: + # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape]. + # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape] + # and we do not need its first two dimensions. + lower_k = alpha = self.alpha['_forward'][0, 0] + elif self.relu_options == "same-slope": lower_k = upper_k elif self.relu_options == "zero-lb": lower_k = torch.zeros_like(upper_k) elif self.relu_options == "one-lb": lower_k = torch.ones_like(upper_k) - elif self.opt_stage == 'opt': - # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape]. - # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape] - # and we do not need its first two dimensions. - lower_k = alpha = self.alpha['_forward'][0, 0] else: # adaptive lower_k = torch.gt(torch.abs(x.upper), torch.abs(x.lower)).to(torch.float) # NOTE #FIXME Saved for initialization bounds for optimization. # In the backward mode, same-slope bounds are used. # But here it is using adaptive bounds which seem to be better - # for nn4sys benchmark with loose input bounds. Need confirmation + # for nn4sys benchmark with loose input bounds. Need confirmation # for other cases. - self.d = lower_k.detach() # saved for initializing optimized bounds - self._add_linear(mask=self.mask_both, type='lower', - k=lower_k, x0=0., y0=0.) 
+ self.lower_d = lower_k.detach() # saved for initializing optimized bounds - def bound_backward(self, last_lA, last_uA, x=None, start_node=None, beta_for_intermediate_layers=False, unstable_idx=None): - if x is not None: - # # only set lower and upper bound here when using neuron set version, ie, not ob_update_by_layer - # if self.beta is not None and not self.options.get('optimize_bound_args', {}).get('ob_update_by_layer', False): - # if self.beta_mask.abs().sum() != 0: - # # set bound neuron-wise according to beta_mask - # x.lower = x.lower * (self.beta_mask != 1).to(torch.float32) - # x.upper = x.upper * (self.beta_mask != -1).to(torch.float32) + self.lw = self.mask_both * lower_k + self.mask_pos - lb_r = x.lower.clamp(max=0) - ub_r = x.upper.clamp(min=0) - else: - lb_r = self.lower.clamp(max=0) - ub_r = self.upper.clamp(min=0) + def bound_dynamic_forward(self, x, max_dim=None, offset=0): + self._init_masks(x) + self.mask_pos = self.mask_pos.to(x.lower) + self.mask_both = self.mask_both.to(x.lower) + + upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper) + w_new = (self.mask_pos.unsqueeze(1) * x.lw + + self.mask_both.unsqueeze(1) * upper_k.unsqueeze(1) * x.lw) + upper_b = self.mask_both * upper_b / 2 + b_new = (self.mask_pos * x.lb + + self.mask_both * upper_k * x.lb + upper_b) + + # Create new variables for unstable ReLU + batch_size = w_new.shape[0] + device = w_new.device + unstable = self.mask_both.view(batch_size, -1) + tot_unstable = int(unstable.sum(dim=-1).max()) + tot_dim = x.tot_dim + tot_unstable + # logger.debug(f'Unstable: {tot_unstable}') + + if offset + w_new.shape[1] < x.tot_dim: + return LinearBound( + w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=tot_dim) + + index = torch.cumsum(unstable, dim=-1).to(torch.int64) + index = (index - (offset + w_new.shape[1] - x.tot_dim)).clamp(min=0) + num_new_dim = int(index.max()) + num_new_dim_actual = min(num_new_dim, max_dim - w_new.shape[1]) + index = index.clamp(max=num_new_dim_actual+1) + w_unstable = torch.zeros(batch_size, num_new_dim_actual + 2, unstable.size(-1), device=device) + x_L_unstable = -torch.ones(batch_size, num_new_dim_actual, device=device) + x_U_unstable = torch.ones(batch_size, num_new_dim_actual, device=device) + w_unstable.scatter_(dim=1, index=index.unsqueeze(1), src=upper_b.view(batch_size, 1, -1), reduce='add') + w_unstable = w_unstable[:, 1:-1].view(batch_size, num_new_dim_actual, *w_new.shape[2:]) + + w_new = torch.cat([w_new, w_unstable], dim=1) + x_L_new = torch.cat([x.x_L, x_L_unstable], dim=-1) + x_U_new = torch.cat([x.x_U, x_U_unstable], dim=-1) + + return LinearBound( + w_new, b_new, w_new, b_new, x_L=x_L_new, x_U=x_U_new, tot_dim=tot_dim) - self.I = ((lb_r != 0) * (ub_r != 0)).detach() # unstable neurons - # print('unstable neurons:', self.I.sum()) - if hasattr(x, 'interval') and Interval.use_relative_bounds(x.interval): - diff_x = x.interval.upper_offset - x.interval.lower_offset - upper_d = (self.interval.upper_offset - self.interval.lower_offset) / diff_x.clamp(min=epsilon) - mask_tiny_diff = (diff_x <= epsilon).float() - upper_d = mask_tiny_diff * F.relu(x.upper) + (1 - mask_tiny_diff) * upper_d + def bound_forward(self, dim_in, x): + self._forward_relaxation(x) + + lb = self.lw * x.lb + ub = self.uw * x.ub + self.ub + + if x.lw is not None: + lw = self.lw.unsqueeze(1) * x.lw + else: + lw = None + if x.uw is not None: + uw = self.uw.unsqueeze(1) * x.uw else: - ub_r = torch.max(ub_r, lb_r + 1e-8) - upper_d = ub_r / (ub_r - lb_r) + uw = None + + if not lw.requires_grad: + del 
self.mask_both, self.mask_pos + del self.lw, self.uw, self.ub + + return LinearBound(lw, lb, uw, ub) + + @staticmethod + @torch.jit.script + def _relu_upper_bound(lb, ub): + """Upper bound slope and intercept according to CROWN relaxation.""" + # TODO: pre-compile all JIT functions before running. + lb_r = lb.clamp(max=0) + ub_r = ub.clamp(min=0) + ub_r = torch.max(ub_r, lb_r + 1e-8) + upper_d = ub_r / (ub_r - lb_r) upper_b = - lb_r * upper_d + return upper_d, upper_b + + @staticmethod + def _relu_mask_alpha(lower, upper, lb_lower_d : Optional[Tensor], ub_lower_d : Optional[Tensor]) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]: + lower_mask = (lower >= 0).requires_grad_(False).to(lower.dtype) + upper_mask = (upper <= 0).requires_grad_(False) + zero_coeffs = upper_mask.all() + no_mask = (1. - lower_mask) * (1. - upper_mask.to(upper.dtype)) + if lb_lower_d is not None: + lb_lower_d = torch.clamp(lb_lower_d, min=0., max=1.) * no_mask + lower_mask + if ub_lower_d is not None: + ub_lower_d = torch.clamp(ub_lower_d, min=0., max=1.) * no_mask + lower_mask + return lb_lower_d, ub_lower_d, zero_coeffs + + def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx): + if x is not None: + lower = x.lower + upper = x.upper + else: + lower = self.lower + upper = self.upper + + # Upper bound slope and intercept according to CROWN relaxation. + upper_d, upper_b = self._relu_upper_bound(lower, upper) flag_expand = False ub_lower_d = lb_lower_d = None - if self.relu_options == "same-slope": - # the same slope for upper and lower - lower_d = upper_d - elif self.relu_options == "zero-lb": - # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN - lower_d = (upper_d >= 1.0).float() - elif self.relu_options == "one-lb": - # Always use slope 1 as lower bound - lower_d = (upper_d > 0.0).float() - elif self.relu_options == "reversed-adaptive": - lower_d = (upper_d < 0.5).float() - elif self.opt_stage == 'opt': + lower_b = None # The lower bound intercept for ReLU is always 0. + alpha_lookup_idx = None # For sparse-spec alpha. + if self.opt_stage in ['opt', 'reuse']: # Alpha-CROWN. lower_d = None # Each alpha has shape (2, output_shape, batch_size, *relu_node_shape]. # If slope is shared, output_shape will be 1. + # The *relu_node_shape might be sparse (sparse-feature alpha), where the non-zero values are indicated by self.alpha_indices. + # The out_shape might be sparse (sparse-spec alpha), where the non-zero values are indexed by self.alpha_lookup_idx. if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1: + # print(f'relu layer {self.name}, start_node {start_node}, unstable_idx {type(unstable_idx)} alpha idx {self.alpha_lookup_idx[start_node.name].size()}') + alpha_lookup_idx = self.alpha_lookup_idx[start_node.name] if isinstance(unstable_idx, tuple): - if isinstance(last_lA, torch.Tensor) or isinstance(last_uA, torch.Tensor): - # Patches mode converted to matrix. Need to select accross the spec dimension. + # Start node is a conv node. + selected_alpha = self.alpha[start_node.name] + if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor): + # Start node is a conv node but we received tensors as A matrices. + # Patches mode converted to matrix, or matrix mode used. Need to select across the spec dimension. # For this node, since it is in matrix mode, the spec dimension is out_c * out_h * out_w - selected_alpha = self.alpha[start_node.name] - # Reshape the spec dimension to c*h*w so we can select based on unstable index.
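For reference, `_relu_upper_bound` computes the chord of the ReLU over `[l, u]`: slope `u / (u - l)` and intercept `-l * u / (u - l)`. A minimal standalone sketch (not part of the diff; values assumed) checking that this line upper-bounds the ReLU on an unstable interval:

```python
import torch

# For an unstable neuron with l < 0 < u, the CROWN upper relaxation is the
# chord through (l, 0) and (u, u); since ReLU is convex, this chord upper
# bounds relu(x) everywhere on [l, u].
l, u = torch.tensor([-1.5]), torch.tensor([2.0])
upper_d = u / (u - l)    # slope, as in _relu_upper_bound
upper_b = -l * upper_d   # intercept
x = torch.linspace(l.item(), u.item(), steps=101)
# Small tolerance for float rounding at the interval endpoints.
assert torch.all(upper_d * x + upper_b >= torch.relu(x) - 1e-6)
```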
- selected_alpha = selected_alpha.view(-1, *start_node.output_shape[1:], *selected_alpha.shape[2:]) - selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] + # Shape is [2, spec, batch, *this_layer_shape] + if alpha_lookup_idx is None: + # Reshape the spec dimension to c*h*w so we can select used alphas based on unstable index. + # Shape becomes [2, out_c, out_h, out_w, batch, *this_layer_shape] + selected_alpha = selected_alpha.view(selected_alpha.size(0), *start_node.output_shape[1:], *selected_alpha.shape[2:]) + selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] + else: + assert alpha_lookup_idx.ndim == 3 + # We only stored some alphas, and A is also sparse, so the unstable_idx must first be translated to real indices. + # alpha shape is (2, sparse_spec_shape, batch_size, *relu_node_shape) where relu_node_shape can also be sparse. + # We use sparse-spec alphas. Need to convert unstable_idx[0], unstable_idx[1], unstable_idx[2] using the lookup table. + _unstable_idx = alpha_lookup_idx[unstable_idx[0], unstable_idx[1], unstable_idx[2]] + selected_alpha = self.non_deter_index_select(selected_alpha, index=_unstable_idx, dim=1) else: - # unstable index for patches mode. Need to choose based on unstable_idx. - # Selection is just based on output channel, and it will be selected when d_pos_unfolded_r and d_neg_unfolded_r are constructed. - selected_alpha = self.alpha[start_node.name] + # Patches mode. Alpha must be selected after unfolding, so cannot be done here. + # Selection is deferred to maybe_unfold() using alpha_lookup_idx. + # For partially shared alpha, its shape is (2, out_c, batch_size, *relu_node_shape). + # For full alpha, its shape is (2, out_c*out_h*out_w, batch_size, *relu_node_shape). + # Both the spec dimension and relu_node_shape dimensions can be sparse. + pass elif unstable_idx.ndim == 1: + # Start node is an FC node. # Only unstable neurons of the start_node neurons are used. - selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) + assert alpha_lookup_idx is None or alpha_lookup_idx.ndim == 1 + _unstable_idx = alpha_lookup_idx[unstable_idx] if alpha_lookup_idx is not None else unstable_idx + selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=_unstable_idx, dim=1) elif unstable_idx.ndim == 2: + assert alpha_lookup_idx is None, "sparse spec alpha has not been implemented yet." # Each element in the batch selects different neurons. selected_alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) else: raise ValueError else: + # Spec dimension is dense. Alpha must not be created sparsely. + assert self.alpha_lookup_idx[start_node.name] is None selected_alpha = self.alpha[start_node.name] - # print(f'{self.name} selecting {start_node.name} alpha {selected_alpha.size()}') # The first dimension is lower/upper intermediate bound.
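To make the sparse-spec indexing above concrete: `alpha_lookup_idx` translates spec indices of the start node into rows of the sparsely stored alpha before the `index_select`. A minimal sketch with hypothetical sizes (not the library's exact data layout):

```python
import torch

# Alphas are stored for only 4 of 10 output specs; the lookup table maps a
# full spec index to its row in the sparse storage (hypothetical sizes).
num_spec, batch, n = 10, 2, 5
stored_specs = torch.tensor([1, 3, 6, 8])
alpha = torch.randn(2, len(stored_specs), batch, n)  # (lb/ub, sparse_spec, batch, neurons)

alpha_lookup_idx = torch.zeros(num_spec, dtype=torch.long)
alpha_lookup_idx[stored_specs] = torch.arange(len(stored_specs))

# Translate unstable spec indices first, then select the matching alpha rows.
unstable_idx = torch.tensor([3, 8])
selected_alpha = torch.index_select(alpha, 1, alpha_lookup_idx[unstable_idx])
print(selected_alpha.shape)  # torch.Size([2, 2, 2, 5])
```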
- if x is not None: - lower = x.lower - upper = x.upper - else: - lower = self.lower - upper = self.upper - lower_mask = lower > 0 - upper_mask = upper < 0 if last_lA is not None: - lb_lower_d = selected_alpha[0].clamp(min=0.0, max=1.0) - lb_lower_d[:, lower_mask] = 1.0 - lb_lower_d[:, upper_mask] = 0.0 + lb_lower_d = selected_alpha[0] if last_uA is not None: - ub_lower_d = selected_alpha[1].clamp(min=0.0, max=1.0) - ub_lower_d[:, lower_mask] = 1.0 - ub_lower_d[:, upper_mask] = 0.0 - self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = upper_mask.all().item() - flag_expand = True + ub_lower_d = selected_alpha[1] + + if self.alpha_indices is not None: + # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only. + # Recover to full alpha first. + def reconstruct_full_alpha(sparse_alpha, full_alpha_shape, alpha_indices): + full_alpha = torch.zeros(full_alpha_shape, dtype=sparse_alpha.dtype, device=sparse_alpha.device) + if len(alpha_indices) == 1: + # Relu after a dense layer. + full_alpha[:, :, alpha_indices[0]] = sparse_alpha + elif len(alpha_indices) == 3: + # Relu after a conv layer. + full_alpha[:, :, alpha_indices[0], alpha_indices[1], alpha_indices[2]] = sparse_alpha + else: + raise ValueError + return full_alpha + sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape + full_alpha_shape = sparse_alpha_shape[:-1] + self.shape + if lb_lower_d is not None: + lb_lower_d = reconstruct_full_alpha(lb_lower_d, full_alpha_shape, self.alpha_indices) + if ub_lower_d is not None: + ub_lower_d = reconstruct_full_alpha(ub_lower_d, full_alpha_shape, self.alpha_indices) + + # condition only on the masked part + if self.alpha_beta_update_mask is not None: + if lb_lower_d is not None: + lb_lower_d_new = lb_lower_d[:, self.alpha_beta_update_mask] + else: + lb_lower_d_new = None + if ub_lower_d is not None: + ub_lower_d_new = ub_lower_d[:, self.alpha_beta_update_mask] + else: + ub_lower_d_new = None + lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d_new, ub_lower_d_new) + else: + lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d, ub_lower_d) + self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = zero_coeffs + flag_expand = True # we already have the spec dimension. + elif self.relu_options == "same-slope": + # the same slope for upper and lower + lower_d = upper_d + elif self.relu_options == "zero-lb": + # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN + lower_d = (upper_d >= 1.0).to(upper_d.dtype) + elif self.relu_options == "one-lb": + # Always use slope 1 as lower bound + lower_d = (upper_d > 0.0).to(upper_d.dtype) + elif self.relu_options == "reversed-adaptive": + lower_d = (upper_d < 0.5).to(upper_d.dtype) else: # adaptive - lower_d = (upper_d > 0.5).float() - - # save for calculate babsr score - self.d = upper_d - self.lA = last_lA - # Save for initialization bounds. - self.lower_d = lower_d - - # assert self.I.sum() == torch.logical_and(0 < self.d, self.d < 1).sum() + lower_d = (upper_d > 0.5).to(upper_d.dtype) # Upper bound always needs an extra specification dimension, since they only depend on lb and ub. upper_d = upper_d.unsqueeze(0) upper_b = upper_b.unsqueeze(0) if not flag_expand: - if self.opt_stage == 'opt': + if self.opt_stage in ['opt', 'reuse']: # We have different slopes for lower and upper bounds propagation. 
lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None else: lower_d = lower_d.unsqueeze(0) + return upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx - mode = "patches" if isinstance(last_lA, Patches) or isinstance(last_uA, Patches) else "matrix" - - # In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return. - def _maybe_unfold(d_tensor, last_A): - if mode == "matrix" or d_tensor is None or last_A is None: - return d_tensor - # Input are slopes with shape (spec, batch, input_c, input_h, input_w) - # Here spec is the same as out_c. - assert d_tensor.ndim == 5 - d_shape = d_tensor.size() - # Reshape to 4-D tensor to unfold. - d_tensor = d_tensor.view(-1, *d_shape[-3:]) - # unfold the slope matrix as patches. Patch shape is [spec * batch, out_h, out_w, in_c, H, W). - d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding) - # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c. - d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:]) - if last_A.unstable_idx is not None: - if d_unfolded_r.size(0) == 1: - # Broadcast the spec shape, so only need to select the reset dimensions. - # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W). - d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5) - d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] - # output shape: (unstable_size, batch, in_c, H, W). - else: - d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] - # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W). - # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W). - return d_unfolded_r + def bound_backward(self, last_lA, last_uA, x=None, start_node=None, beta_for_intermediate_layers=False, unstable_idx=None): + # Get element-wise CROWN linear relaxations. + upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx = \ + self._backward_relaxation(last_lA, last_uA, x, start_node, unstable_idx) + # Save for calculating the BaBSR score. + self.d = upper_d + self.lA = last_lA + # Save for initialization bounds. + self.lower_d = lower_d # Choose upper or lower bounds based on the sign of last_A def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): if last_A is None: return None, 0 + # Obtain the new linear relaxation coefficients based on the signs in last_A. + _A, _bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg) + if isinstance(last_A, Patches): + # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. + A_prod = _A.patches + if start_node is not None: + if last_A.unstable_idx is not None: + # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). + self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] + else: + # Regular patches.
+ self.patch_size[start_node.name] = A_prod.size() + return _A, _bias - if type(last_A) == torch.Tensor: - # multiply according to sign of A (we use fused operation to save memory) - # neg_A = last_A.clamp(max=0) - # pos_A = last_A.clamp(min=0) - # A = d_pos * pos_A + d_neg * neg_A - A, pos_A, neg_A = self.clamp_mutiply(last_A, d_pos, d_neg) - bias = 0 - if b_pos is not None: - bias = bias + torch.einsum('sb...,sb...->sb', pos_A, b_pos) - if b_neg is not None: - bias = bias + torch.einsum('sb...,sb...->sb', neg_A, b_neg) - return A, bias - elif type(last_A) == Patches: - # if last_A is not an identity matrix - assert last_A.identity == 0 - if last_A.identity == 0: - # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension. - # or (unstable_size, batch_size, in_c, H, W) when it is sparse. - patches = last_A.patches - prod, pos_A_patches, neg_A_patches = self.clamp_mutiply_non_contiguous(patches, d_pos, d_neg) - # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse. - - # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. - if start_node is not None: - if last_A.unstable_idx is not None: - # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). - self.patch_size[start_node.name] = [last_A.output_shape[1], prod.size(1), last_A.output_shape[2], last_A.output_shape[3], prod.size(-3), prod.size(-2), prod.size(-1)] - else: - # Regular patches. - self.patch_size[start_node.name] = prod.size() - - bias = 0 - if b_pos is not None: - # For sparse patches the return bias size is (unstable_size, batch). - # For regular patches the return bias size is (spec, batch, out_h, out_w). - bias = bias + torch.einsum('sb...chw,sb...chw->sb...', b_pos, pos_A_patches) - if b_neg is not None: - bias = bias + torch.einsum('sb...chw,sb...chw->sb...', b_neg, neg_A_patches) - return Patches(prod, last_A.stride, last_A.padding, prod.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), bias + ######## A problem with patches mode for cut constraint start ########## + # There are cases where a node appears in the cut constraint but is not selected by the patches for the output node. + # Trick: only count the small patches where the split node coeffs[ci].sum() equals coeffs_unfolded[ci][out_h, out_w, -1].sum(). + # We force these betas to be 0 to disable the effect of such constraints. + A = last_lA if last_lA is not None else last_uA + current_layer_shape = x.lower.size()[1:] + if self.cut_used and type(A) is Patches: + self.cut_module.patch_trick(start_node, self.name, A, current_layer_shape) + ######## A problem with patches mode for cut constraint end ########## + + if self.cut_used: + # propagate postrelu node in cut constraints + last_lA, last_uA = self.cut_module.relu_cut( + start_node, self.name, last_lA, last_uA, current_layer_shape, unstable_idx, + batch_mask=self.alpha_beta_update_mask) # In patches mode we might need an unfold.
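The unfold mentioned in the trailing comment rearranges element-wise slopes into the sliding-window layout of a `Patches` A matrix. A condensed sketch using the stock `torch.nn.functional.unfold` (hypothetical sizes; the library's `maybe_unfold_patches` additionally handles sparse indices, inserted zeros, and output padding):

```python
import torch
import torch.nn.functional as F

# Each conv output location sees only a k x k window of this layer, so the
# per-element slopes must be unfolded into windows before they can be
# multiplied with the patches of A.
batch, in_c, h, w, k = 2, 3, 8, 8, 3
d = torch.rand(batch, in_c, h, w)             # element-wise slopes
d_unf = F.unfold(d, kernel_size=k, stride=1)  # (batch, in_c*k*k, L)
L = d_unf.size(-1)                            # number of output locations
d_patches = d_unf.transpose(1, 2).view(batch, L, in_c, k, k)
print(d_patches.shape)  # torch.Size([2, 36, 3, 3, 3])
```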
- upper_d = _maybe_unfold(upper_d, last_lA if last_lA is not None else last_uA) - lower_d = _maybe_unfold(lower_d, last_lA if last_lA is not None else last_uA) - upper_b = _maybe_unfold(upper_b, last_lA if last_lA is not None else last_uA) - ub_lower_d = _maybe_unfold(ub_lower_d, last_uA) - lb_lower_d = _maybe_unfold(lb_lower_d, last_lA) - - uA, ubias = _bound_oneside(last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, None) - lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, None, upper_b) - - self.masked_beta_lower = self.masked_beta_upper = None - if self.options.get('optimize_bound_args', {}).get('ob_beta', False): - if self.options.get('optimize_bound_args', {}).get('ob_single_node_split', False): - # Beta-CROWN. - A = last_uA if last_uA is not None else last_lA - if type(A) is Patches: + # lower_d, upper_d, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None + upper_d = maybe_unfold_patches(upper_d, last_lA if last_lA is not None else last_uA) + lower_d = maybe_unfold_patches(lower_d, last_lA if last_lA is not None else last_uA) + upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA) + lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA) # for ReLU it is always None; keeping it here for completeness. + # ub_lower_d and lb_lower_d might have sparse spec dimension, so they may need alpha_lookup_idx to convert to actual spec dim. + ub_lower_d = maybe_unfold_patches(ub_lower_d, last_uA, alpha_lookup_idx=alpha_lookup_idx) + # optimizable slope lb_lower_d: spec (only channels in spec layer), batch, current_c, current_w, current_h + # patches mode lb_lower_d after unfold: unstable, batch, in_C, H, W + lb_lower_d = maybe_unfold_patches(lb_lower_d, last_lA, alpha_lookup_idx=alpha_lookup_idx) + + if self.cut_used: + I = (x.lower < 0) * (x.upper > 0) + # propagate integer var of relu neuron (arelu) in cut constraints through relu layer + lA, uA, lbias, ubias = self.cut_module.arelu_cut( + start_node, self.name, last_lA, last_uA, lower_d, upper_d, + lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, self.patch_size, + current_layer_shape, unstable_idx, + batch_mask=self.alpha_beta_update_mask) + else: + uA, ubias = _bound_oneside( + last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, + upper_b, lower_b) + lA, lbias = _bound_oneside( + last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, + lower_b, upper_b) + + # Regular Beta CROWN with single neuron split + def _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx): + if type(A) is Patches: + if self.options.get('enable_opt_interm_bounds', False): + # expand sparse_beta to full beta + beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name]) + beta_indices = self.sparse_beta_loc[start_node.name] + self.masked_beta = torch.zeros(2, *self.shape).reshape(2, -1).to(A.patches.dtype) + self.non_deter_scatter_add(self.masked_beta, dim=1, index=beta_indices, src=beta_values.to(self.masked_beta.dtype)) + self.masked_beta = self.masked_beta.reshape(2, *self.shape) + else: + if self.beta is None: + # Beta not used. + return lA, uA # For patches mode, masked_beta will be used; sparse beta is not supported. 
self.masked_beta = (self.beta[0] * self.beta_mask).requires_grad_() - A_patches = A.patches - # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W) - masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride) - if A.unstable_idx is not None: - masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4) - # After selection, the shape is (unstable_size, batch, in_c, H, W). - masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]] - else: - # Add the spec (out_c) dimension. - masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0) - if uA is not None: - uA = Patches(uA.patches + masked_beta_unfolded, uA.stride, uA.padding, uA.patches.shape, unstable_idx=uA.unstable_idx, output_shape=uA.output_shape) - if lA is not None: - lA = Patches(lA.patches - masked_beta_unfolded, lA.stride, lA.padding, lA.patches.shape, unstable_idx=lA.unstable_idx, output_shape=lA.output_shape) - elif type(A) is torch.Tensor: + # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W) + A_patches = A.patches + masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding) + if A.unstable_idx is not None: + masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4, 5) + # After selection, the shape is (unstable_size, batch, in_c, H, W). + masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]] + else: + # Add the spec (out_c) dimension. + masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0) + if self.alpha_beta_update_mask is not None: + masked_beta_unfolded = masked_beta_unfolded[self.alpha_beta_update_mask] + if uA is not None: + uA = uA.create_similar(uA.patches + masked_beta_unfolded) + if lA is not None: + lA = lA.create_similar(lA.patches - masked_beta_unfolded) + elif type(A) is Tensor: + if self.options.get('enable_opt_interm_bounds', False): + # For matrix mode, beta is sparse. + beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name]).expand(lA.size(0), -1, -1) + # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. + beta_indices = self.sparse_beta_loc[start_node.name].unsqueeze(0).expand(lA.size(0), -1, -1) + else: # For matrix mode, beta is sparse. beta_values = (self.sparse_beta * self.sparse_beta_sign).expand(lA.size(0), -1, -1) # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. beta_indices = self.sparse_beta_loc.unsqueeze(0).expand(lA.size(0), -1, -1) - # For conv layer, the last dimension is flattened in indices. - prev_size = A.size() - if uA is not None: - uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=beta_indices, src=beta_values) - uA = uA.view(prev_size) - if lA is not None: - lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=beta_indices, src=beta_values.neg()) - lA = lA.view(prev_size) - else: - raise RuntimeError(f"Unknown type {type(A)} for A") - # The code block below is for debugging and will be removed (until the end of this function). 
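In the matrix (`Tensor`) branch above, each single-neuron split contributes a signed beta term at the split neuron's position in A; the `non_deter_scatter_add` calls that complete this branch later in the hunk implement that update. A minimal sketch with plain `scatter_add` and hypothetical sizes:

```python
import torch

# uA has shape (spec, batch, #neurons); each batch element carries up to
# max_splits split locations with signed beta values. Unused slots hold
# beta 0, so padding their index with 0 is harmless.
spec, batch, n = 3, 2, 6
uA = torch.zeros(spec, batch, n)
beta_values = torch.tensor([[0.5, 0.2], [0.3, 0.0]])  # (batch, max_splits)
beta_loc = torch.tensor([[1, 4], [2, 0]])             # split neuron indices

# Expand over the spec dimension and scatter-add into the flattened A.
src = beta_values.unsqueeze(0).expand(spec, -1, -1)
index = beta_loc.unsqueeze(0).expand(spec, -1, -1)
uA = uA.scatter_add(2, index, src)  # the lower bound uses beta_values.neg()
```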
- elif not self.options.get('optimize_bound_args', {}).get('ob_single_node_split', True): - A = uA if uA is not None else lA - if type(A) == torch.Tensor: - device = A.device - else: - device = A.patches.device - print_time = False - - if self.single_beta_used or self.split_beta_used or self.history_beta_used: - start_time = time.time() - history_compute_time, split_compute_time, split_convert_time = 0, 0, 0 - history_compute_time1, history_compute_time2 = 0, 0 - # assert len(self.split_beta) > 0, "split_beta_used or history_beta_used is True means there have to be one relu in one batch is used in split constraints" - if self.single_beta_used: - if beta_for_intermediate_layers: - # We handle the refinement of intermediate layer after this split layer here. (the refinement for intermediate layers before the split is handled in compute_bounds(). - # print(f'single node beta for {start_node.name} with beta shape {self.single_intermediate_betas[start_node.name]["ub"].size()}') - assert not self.history_beta_used - assert not self.history_beta_used - assert type(A) is not Patches - if uA is not None: - # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta]) - single_intermediate_beta = self.single_intermediate_betas[start_node.name]['ub'] - single_intermediate_beta = single_intermediate_beta.view( - single_intermediate_beta.size(0), -1, single_intermediate_beta.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - single_intermediate_beta = self.non_deter_index_select(single_intermediate_beta, index=unstable_idx, dim=1) - # This is the sign. - single_intermediate_beta = single_intermediate_beta * self.single_beta_sign.unsqueeze(1) - # We now generate a large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA. - prev_size = uA.size() - # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. - indices = self.single_beta_loc.unsqueeze(0).expand(uA.size(0), -1, -1) - # We update uA here directly using sparse operation. Note the spec dimension is at the first! - uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=indices, src=single_intermediate_beta.transpose(0,1)) - uA = uA.view(prev_size) - if lA is not None: - # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta]) - single_intermediate_beta = self.single_intermediate_betas[start_node.name]['lb'] - single_intermediate_beta = single_intermediate_beta.view( - single_intermediate_beta.size(0), -1, single_intermediate_beta.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - single_intermediate_beta = self.non_deter_index_select(single_intermediate_beta, index=unstable_idx, dim=1) - # This is the sign, for lower bound we need to negate. - single_intermediate_beta = single_intermediate_beta * ( - self.single_beta_sign.unsqueeze(1)) - # We now generate a large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA. - prev_size = lA.size() - # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. - indices = self.single_beta_loc.unsqueeze(0).expand(lA.size(0), -1, -1) - # We update lA here directly using sparse operation. Note the spec dimension is at the first! 
- lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=indices, src=single_intermediate_beta.transpose(0,1)) - lA = lA.view(prev_size) - else: - self.masked_beta_lower = self.masked_beta_upper = self.masked_beta = self.beta * self.beta_mask - - ############################ - # sparse_coo version for history coeffs - if self.history_beta_used: - # history_compute_time = time.time() - if beta_for_intermediate_layers: - # print(f'history intermediate beta for {start_node.name} with beta shape {self.history_intermediate_betas[start_node.name]["ub"].size()}') - if uA is not None: - # The beta for start_node has shape ([batch, prod(start_node.shape), n_max_history_beta]) - history_intermediate_beta = self.history_intermediate_betas[start_node.name]['ub'] - history_intermediate_beta = history_intermediate_beta.view( - history_intermediate_beta.size(0), -1, history_intermediate_beta.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - history_intermediate_beta = self.non_deter_index_select(history_intermediate_beta, index=unstable_idx, dim=1) - # new_history_coeffs has shape (batch, prod(nodes), n_max_history_beta) - # new_history_c has shape (batch, n_max_history_beta) - # This can generate a quite large matrix in shape (batch, prod(start_node.shape), prod(nodes)) which is the same size as uA and lA. - self.masked_beta_upper = torch.bmm(history_intermediate_beta, ( - self.new_history_coeffs * self.new_history_c.unsqueeze(1)).transpose(-1, - -2)) - if lA is not None: - history_intermediate_beta = self.history_intermediate_betas[start_node.name]['lb'] - history_intermediate_beta = history_intermediate_beta.view( - history_intermediate_beta.size(0), -1, history_intermediate_beta.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - history_intermediate_beta = self.non_deter_index_select(history_intermediate_beta, index=unstable_idx, dim=1) - self.masked_beta_lower = torch.bmm(history_intermediate_beta, ( - self.new_history_coeffs * self.new_history_c.unsqueeze(1)).transpose(-1, - -2)) - else: - # new_history_coeffs has shape (batch, prod(nodes), n_max_history_beta) - # new_history_beta has shape (batch, m_max_history_beta) - self.masked_beta_lower = self.masked_beta_upper = torch.bmm(self.new_history_coeffs, ( - self.new_history_beta * self.new_history_c).unsqueeze(-1)).squeeze(-1) - - # new split constraint - if self.split_beta_used: - split_convert_time = time.time() - if self.split_coeffs["dense"] is None: - assert not hasattr(self, 'split_intermediate_betas') # intermediate beta split must use the dense mode. 
- ##### we can use repeat to further save the conversion time - # since the new split constraint coeffs can be optimized, we can just save the index and assign optimized coeffs value to the sparse matrix - self.new_split_coeffs = torch.zeros(self.split_c.size(0), self.flattened_nodes, - dtype=torch.get_default_dtype(), device=device) - # assign coeffs value to the first half batch - self.new_split_coeffs[ - (self.split_coeffs["nonzero"][:, 0], self.split_coeffs["nonzero"][:, 1])] = \ - self.split_coeffs["coeffs"] - # # assign coeffs value to the rest half batch with the same values since split constraint shared the same coeffs for >0/<0 - self.new_split_coeffs[(self.split_coeffs["nonzero"][:, 0] + int(self.split_c.size(0) / 2), - self.split_coeffs["nonzero"][:, 1])] = self.split_coeffs["coeffs"] - else: - # batch = int(self.split_c.size(0)/2) - # assign coeffs value to the first half batch and the second half batch - self.new_split_coeffs = self.split_coeffs["dense"].repeat(2, 1) - split_convert_time = time.time() - split_convert_time - split_compute_time = time.time() - if beta_for_intermediate_layers: - assert hasattr(self, 'split_intermediate_betas') - # print(f'split intermediate beta for {start_node.name} with beta shape {self.split_intermediate_betas[start_node.name]["ub"].size()}') - if uA is not None: - # upper bound betas for this set of intermediate neurons. - # Make an extra spec dimension. Now new_split_coeffs has size (batch, specs, #nodes). Specs is the number of intermediate neurons of start node. The same split will be applied to all specs in a batch element. - # masked_beta_upper has shape (batch, spec, #nodes) - split_intermediate_betas = self.split_intermediate_betas[start_node.name]['ub'] - split_intermediate_betas = split_intermediate_betas.view(split_intermediate_betas.size(0), -1, split_intermediate_betas.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - split_intermediate_betas = self.non_deter_index_select(split_intermediate_betas, index=unstable_idx, dim=1) - self.split_masked_beta_upper = split_intermediate_betas * ( - self.new_split_coeffs * self.split_c).unsqueeze(1) - if lA is not None: - split_intermediate_betas = self.split_intermediate_betas[start_node.name]['lb'] - split_intermediate_betas = split_intermediate_betas.view(split_intermediate_betas.size(0), -1, split_intermediate_betas.size(-1)) - if unstable_idx is not None: - # Only unstable neurons of the start_node neurons are used. - split_intermediate_betas = self.non_deter_index_select(split_intermediate_betas, index=unstable_idx, dim=1) - self.split_masked_beta_lower = split_intermediate_betas * ( - self.new_split_coeffs * self.split_c).unsqueeze(1) - else: - # beta for final objective only. TODO: distinguish between lb and ub. - self.split_masked_beta_upper = self.split_masked_beta_lower = self.new_split_coeffs * ( - self.split_beta * self.split_c) - # add the new split constraint beta to the masked_beta - if self.masked_beta_upper is None: - self.masked_beta_upper = self.split_masked_beta_upper - else: - self.masked_beta_upper = self.masked_beta_upper + self.split_masked_beta_upper + # For conv layer, the last dimension is flattened in indices. 
+ prev_size = A.size() + if self.alpha_beta_update_mask is not None: + beta_indices = beta_indices[:, self.alpha_beta_update_mask] + beta_values = beta_values[:, self.alpha_beta_update_mask] + if uA is not None: + uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=beta_indices, src=beta_values.to(uA.dtype)) + uA = uA.view(prev_size) + if lA is not None: + lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=beta_indices, src=beta_values.neg().to(lA.dtype)) + lA = lA.view(prev_size) + else: + raise RuntimeError(f"Unknown type {type(A)} for A") + return lA, uA - if self.masked_beta_lower is None: - self.masked_beta_lower = self.split_masked_beta_lower - else: - self.masked_beta_lower = self.masked_beta_lower + self.split_masked_beta_lower - # For backwards compatibility - we originally only have one beta. - self.masked_beta = self.masked_beta_lower - split_compute_time = time.time() - split_compute_time - - A = last_uA if last_uA is not None else last_lA - if type(A) is Patches: - assert not hasattr(self, 'split_intermediate_betas') - assert not hasattr(self, 'single_intermediate_betas') - A_patches = A.patches - # Reshape beta to image size. - self.masked_beta = self.masked_beta.view(self.masked_beta.size(0), *ub_r.size()[1:]) - # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W) - masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride) - if A.unstable_idx is not None: - masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4) - # After selection, the shape is (unstable_size, batch, in_c, H, W). - masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]] - else: - # Add the spec (out_c) dimension. - masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0) - if uA is not None: - uA = Patches(uA.patches + masked_beta_unfolded, uA.stride, uA.padding, uA.patches.shape, unstable_idx=uA.unstable_idx, output_shape=uA.output_shape) - if lA is not None: - lA = Patches(lA.patches - masked_beta_unfolded, lA.stride, lA.padding, lA.patches.shape, unstable_idx=lA.unstable_idx, output_shape=lA.output_shape) - elif type(A) is torch.Tensor: - if uA is not None: - # print("uA", uA.shape, self.masked_beta.shape) - # uA/lA has shape (spec, batch, *nodes) - if beta_for_intermediate_layers: - if not self.single_beta_used: - # masked_beta_upper has shape (batch, spec, #nodes) - self.masked_beta_upper = self.masked_beta_upper.transpose(0, 1) - self.masked_beta_upper = self.masked_beta_upper.view(self.masked_beta_upper.size(0), - self.masked_beta_upper.size(1), - *uA.shape[2:]) - else: - # masked_beta_upper has shape (batch, #nodes) - self.masked_beta_upper = self.masked_beta_upper.reshape(uA[0].shape).unsqueeze(0) - if not self.single_beta_used or not beta_for_intermediate_layers: - # For intermediate layer betas witn single node split, uA has been modified above. 
- uA = uA + self.masked_beta_upper - if lA is not None: - # print("lA", lA.shape, self.masked_beta.shape) - if beta_for_intermediate_layers: - if not self.single_beta_used: - # masked_beta_upper has shape (batch, spec, #nodes) - self.masked_beta_lower = self.masked_beta_lower.transpose(0, 1) - self.masked_beta_lower = self.masked_beta_lower.view(self.masked_beta_lower.size(0), - self.masked_beta_lower.size(1), - *lA.shape[2:]) - else: - # masked_beta_upper has shape (batch, #nodes) - self.masked_beta_lower = self.masked_beta_lower.reshape(lA[0].shape).unsqueeze(0) - if not self.single_beta_used or not beta_for_intermediate_layers: - # For intermediate layer betas witn single node split, lA has been modified above. - lA = lA - self.masked_beta_lower - else: - raise RuntimeError(f"Unknown type {type(A)} for A") - # print("total:", time.time()-start_time, history_compute_time1, history_compute_time2, split_convert_time, split_compute_time) + if self.cut_used: + # propagate prerelu node in cut constraints + lA, uA = self.cut_module.pre_cut(start_node, self.name, lA, uA, current_layer_shape, unstable_idx, + batch_mask=self.alpha_beta_update_mask) + self.masked_beta_lower = self.masked_beta_upper = None + if self.options.get('optimize_bound_args', {}).get('enable_beta_crown', False) and self.sparse_beta is not None: + if self.options.get('optimize_bound_args', {}).get('single_node_split', False): + # Beta-CROWN: each split constraint only involves a single neuron (e.g., the second ReLU neuron > 0). + A = lA if lA is not None else uA + lA, uA = _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx) + # The code block below is for debugging and will be removed (until the end of this function). + # elif False and not self.options.get('optimize_bound_args', {}).get('single_node_split', True): + # # Improved Beta-CROWN: (1) general split constraints: each split constraint may involve multiple neurons + # # (e.g., the second ReLU neuron > 0); (2) intermediate ReLU bounds refinement with the general split constraints.
+ # A = uA if uA is not None else lA + # lA, uA = _beta_crown_multi_neuron_splits(x, A, uA, lA, unstable_idx, start_node) + # print(lA.sum(), uA.sum()) + # exit() return [(lA, uA)], lbias, ubias def interval_propagate(self, *v): - if Interval.use_relative_bounds(*v): - nominal = F.relu(v[0].nominal) - mask_nominal = (nominal > 0).float() - mask_l = (v[0].lower > 0).float() - mask_u = (v[0].upper > 0).float() - lower_offset = mask_nominal * (mask_l * v[0].lower_offset + (1 - mask_l) * (-nominal)) - upper_offset = mask_nominal * v[0].upper_offset + (1 - mask_nominal) * mask_u * v[0].upper - return Interval(None, None, nominal, lower_offset, upper_offset) - h_L, h_U = v[0][0], v[0][1] - return F.relu(h_L), F.relu(h_U) - def bound_forward(self, dim_in, x): - return super().bound_forward(dim_in, x) - -class BoundSqrt(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True - - @Bound.save_io_shape - def forward(self, x): - return torch.sqrt(x) - - def interval_propagate(self, *v): - if Interval.use_relative_bounds(*v): - nominal = self.forward(v[0].nominal) - lower_offset = self.forward(v[0].nominal + v[0].lower_offset) - nominal - upper_offset = self.forward(v[0].nominal + v[0].upper_offset) - nominal - return Interval(None, None, nominal, lower_offset, upper_offset) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., last layer input gurobi vars (8,16,16) + gvars_array = np.array(v[0]) + this_layer_shape = gvars_array.shape + assert gvars_array.shape == self.output_shape[1:] + + pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1) + pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1) + + new_layer_gurobi_vars = [] + relu_integer_vars = [] + new_relu_layer_constrs = [] + # predefined zero variable shared in the whole solver model + zero_var = model.getVarByName("zero") + + for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)): + pre_ub = pre_ubs[neuron_idx] + pre_lb = pre_lbs[neuron_idx] + + if pre_lb >= 0: + # ReLU is always passing + var = pre_var + elif pre_ub <= 0: + var = zero_var + else: + ub = pre_ub + + var = model.addVar(ub=ub, lb=pre_lb, + obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'ReLU{self.name}_{neuron_idx}') + + if model_type == "mip" or model_type == "lp_integer": + # binary indicator + if model_type == "mip": + a = model.addVar(vtype=grb.GRB.BINARY, name=f'aReLU{self.name}_{neuron_idx}') + elif model_type == "lp_integer": + a = model.addVar(ub=1, lb=0, vtype=grb.GRB.CONTINUOUS, name=f'aReLU{self.name}_{neuron_idx}') + relu_integer_vars.append(a) + + new_relu_layer_constrs.append( + model.addConstr(pre_var - pre_lb * (1 - a) >= var, + name=f'ReLU{self.name}_{neuron_idx}_a_0')) + new_relu_layer_constrs.append( + model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) + new_relu_layer_constrs.append( + model.addConstr(pre_ub * a >= var, name=f'ReLU{self.name}_{neuron_idx}_a_2')) + new_relu_layer_constrs.append( + model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_3')) + + elif model_type == "lp": + new_relu_layer_constrs.append( + model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_0')) + new_relu_layer_constrs.append( + model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) + new_relu_layer_constrs.append(model.addConstr( + pre_ub * pre_var - (pre_ub - pre_lb) * var >= 
pre_ub * pre_lb, + name=f'ReLU{self.name}_{neuron_idx}_a_2')) - return super().interval_propagate(*v) + else: + print(f"gurobi model type {model_type} not supported!") -class BoundReciprocal(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True + new_layer_gurobi_vars.append(var) - @Bound.save_io_shape - def forward(self, x): - return torch.reciprocal(x) + new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist() + if model_type in ["mip", "lp_integer"]: + self.integer_vars = relu_integer_vars + self.solver_vars = new_layer_gurobi_vars + self.solver_constrs = new_relu_layer_constrs + model.update() - def bound_relax(self, x): - m = (x.lower + x.upper) / 2 - kl = -1 / m.pow(2) - self._add_linear(mask=None, type='lower', k=kl, x0=m, y0=1. / m) - ku = -1. / (x.lower * x.upper) - self._add_linear(mask=None, type='upper', k=ku, x0=x.lower, y0=1. / x.lower) + def dump_optimized_params(self): + return { + 'alpha': self.alpha, + 'alpha_lookup_idx': self.alpha_lookup_idx, + 'alpha_indices': self.alpha_indices + } - def interval_propagate(self, *v): - h_L, h_U = v[0][0].float(), v[0][1].float() - assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal' - return torch.reciprocal(h_U), torch.reciprocal(h_L) + def restore_optimized_params(self, opt_var_dict): + self.alpha, self.alpha_lookup_idx, self.alpha_indices = \ + opt_var_dict['alpha'], opt_var_dict['alpha_lookup_idx'], opt_var_dict['alpha_indices'] -class BoundSin(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.max_point = math.pi / 2 - self.min_point = math.pi * 3 / 2 +class BoundLeakyRelu(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.options = options.get('relu') + self.alpha = attr['alpha'] - @Bound.save_io_shape def forward(self, x): - return torch.sin(x) - - def interval_propagate(self, *v): - # Check if a point is in [l, u], considering the 2pi period - def check_crossing(ll, uu, point): - return ((((uu - point) / (2 * math.pi)).floor() - ((ll - point) / (2 * math.pi)).floor()) > 0).float() - h_L, h_U = v[0][0], v[0][1] - h_Ls, h_Us = self.forward(h_L), self.forward(h_U) - # If crossing pi/2, then max is fixed 1.0 - max_mask = check_crossing(h_L, h_U, self.max_point) - # If crossing pi*3/2, then min is fixed -1.0 - min_mask = check_crossing(h_L, h_U, self.min_point) - ub = torch.max(h_Ls, h_Us) - ub = max_mask + (1 - max_mask) * ub - lb = torch.min(h_Ls, h_Us) - lb = - min_mask + (1 - min_mask) * lb - return lb, ub + return F.leaky_relu(x, negative_slope=self.alpha) - def bound_backward(self, last_lA, last_uA, *x, start_node=None, start_shape=None): - return not_implemented_op(self, 'bound_backward') + def bound_backward(self, last_lA, last_uA, x=None, start_node=None, start_shape=None): + if x is not None: + lb_r = x.lower.clamp(max=0) + ub_r = x.upper.clamp(min=0) + else: + lb_r = self.lower.clamp(max=0) + ub_r = self.upper.clamp(min=0) + ub_r = torch.max(ub_r, lb_r + 1e-8) + upper_d = (ub_r - self.alpha * lb_r) / (ub_r - lb_r) + upper_b = - lb_r * upper_d + self.alpha * lb_r + if self.options == "same-slope": + # the same slope for upper and lower + lower_d = 
upper_d + elif self.options == "zero-lb": + # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN + lower_d = (upper_d >= 1.0).to(upper_d.dtype) + (upper_d < 1.0).to(upper_d.dtype) * self.alpha + elif self.options == "one-lb": + # Always use slope 1 as lower bound + lower_d = (upper_d > 0.0).to(upper_d.dtype) + (upper_d <= 0.0).to(upper_d.dtype) * self.alpha + else: + lower_d = (upper_d > 0.5).to(upper_d.dtype) + (upper_d <= 0.5).to(upper_d.dtype) * self.alpha -class BoundCos(BoundSin): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.max_point = 0.0 - self.min_point = math.pi + upper_d = upper_d.unsqueeze(0) + lower_d = lower_d.unsqueeze(0) + # Choose upper or lower bounds based on the sign of last_A + uA = lA = None + ubias = lbias = 0 + if last_uA is not None: + neg_uA = last_uA.clamp(max=0) + pos_uA = last_uA.clamp(min=0) + uA = upper_d * pos_uA + lower_d * neg_uA + ubias = self.get_bias(pos_uA, upper_b) + if last_lA is not None: + neg_lA = last_lA.clamp(max=0) + pos_lA = last_lA.clamp(min=0) + lA = upper_d * neg_lA + lower_d * pos_lA + lbias = self.get_bias(neg_lA, upper_b) + return [(lA, uA)], lbias, ubias - @Bound.save_io_shape - def forward(self, x): - return torch.cos(x) + def dump_optimized_params(self): + return self.alpha + def restore_optimized_params(self, alpha): + self.alpha = alpha class BoundTanh(BoundOptimizableActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.precompute_relaxation('tanh', torch.tanh, self.dtanh) + # Alpha dimension is (4, 2, output_shape, batch, *shape) for Tanh. + self.alpha_batch_dim = 3 def opt_init(self): super().opt_init() - self.tp_both_lower_init = {} + self.tp_both_lower_init = {} self.tp_both_upper_init = {} def init_opt_parameters(self, start_nodes): - self.alpha = OrderedDict() l, u = self.inputs[0].lower, self.inputs[0].upper shape = l.shape - for ns, size_s in start_nodes: + for ns, size_s, _ in start_nodes: + if isinstance(size_s, torch.Size): + size_s = prod(size_s) self.alpha[ns] = torch.empty(4, 2, size_s, *shape, device=l.device) self.alpha[ns].data[:2] = ((l + u) / 2).unsqueeze(0).expand(2, 2, size_s, *shape) self.alpha[ns].data[2] = self.tp_both_lower_init[ns].expand(2, size_s, *shape) @@ -863,15 +986,16 @@ def init_opt_parameters(self, start_nodes): def dtanh(self, x): # to avoid bp error when cosh is too large # cosh(25.0)**2 > 1e21 - mask = torch.lt(torch.abs(x), 25.0).float() + mask = torch.lt(torch.abs(x), 25.0).to(x.dtype) cosh = torch.cosh(mask * x + 1 - mask) return mask * (1. / cosh.pow(2)) - """Precompute relaxation parameters for tanh and sigmoid""" - @torch.no_grad() - def precompute_relaxation(self, name, func, dfunc): - self.x_limit = 500 + def precompute_relaxation(self, name, func, dfunc, x_limit=500): + """ + This function precomputes the tangent lines that will be used as lower/upper bounds for S-shaped functions.
+ """ + self.x_limit = x_limit self.step_pre = 0.01 self.num_points_pre = int(self.x_limit / self.step_pre) max_iter = 100 @@ -879,35 +1003,49 @@ def precompute_relaxation(self, name, func, dfunc): logger.debug('Precomputing relaxation for {}'.format(name)) def check_lower(upper, d): + """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper.""" k = dfunc(d) + # Return True if the slope is a lower bound. return k * (upper - d) + func(d) <= func(upper) def check_upper(lower, d): + """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower.""" k = dfunc(d) + # Return True if the slope is a upper bound. return k * (lower - d) + func(d) >= func(lower) + # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function. upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) r = torch.zeros_like(upper) + # Initial guess, the tangent line is at -1. l = -torch.ones_like(upper) while True: + # Check if the tangent line at the guessed point is an lower bound at f(upper). checked = check_lower(upper, l).int() + # If the initial guess is not smaller enough, then double it (-2, -4, etc). l = checked * l + (1 - checked) * (l * 2) - if checked.sum() == l.numel(): + if checked.sum() == l.numel(): break + # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). + # We want to further tighten this bound by moving it closer to 0. for t in range(max_iter): + # Binary search. m = (l + r) / 2 checked = check_lower(upper, m).int() l = checked * m + (1 - checked) * l r = checked * r + (1 - checked) * m + # At upper, a line with slope l is guaranteed to lower bound the function. self.d_lower = l.clone() + # Do the same again: + # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) l = torch.zeros_like(upper) r = torch.ones_like(upper) while True: checked = check_upper(lower, r).int() r = checked * r + (1 - checked) * (r * 2) - if checked.sum() == l.numel(): + if checked.sum() == l.numel(): break for t in range(max_iter): m = (l + r) / 2 @@ -918,12 +1056,11 @@ def check_upper(lower, d): logger.debug('Done') - @Bound.save_io_shape def forward(self, x): return torch.tanh(x) def bound_relax_impl(self, x, func, dfunc): - # When self.x_limit is large enough, torch.tanh(self.x_limit)=1, + # When self.x_limit is large enough, torch.tanh(self.x_limit)=1, # and thus clipping is valid lower = x.lower.clamp(min=-self.x_limit) upper = x.upper.clamp(max=self.x_limit) @@ -931,101 +1068,134 @@ def bound_relax_impl(self, x, func, dfunc): min_preact = 1e-6 mask_close = (upper - lower) < min_preact - k_direct = k = torch.where(mask_close, - dfunc(upper), (y_u - y_l) / (upper - lower).clamp(min=min_preact)) - - # Fixed bounds that cannot be optimized - # upper bound for negative - self._add_linear(mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l) - # lower bound for positive - self._add_linear(mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l) - + # k_direct is the slope of the line directly connect (lower, func(lower)), (upper, func(upper)). + k_direct = k = torch.where(mask_close, + dfunc(upper), (y_u - y_l) / (upper - lower).clamp(min=min_preact)) + + # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0. 
+ # Fixed bounds that cannot be optimized. self.mask_neg is the mask for neurons with upper bound <= 0. + # Upper bound for the case of input upper bound <= 0 is always the direct line. + self.add_linear_relaxation(mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l) + # Lower bound for the case of input lower bound >= 0 is always the direct line. + self.add_linear_relaxation(mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l) + + # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. + # Neurons whose input lower bound is also >= 0 will be masked later. index = torch.max( torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), (upper / self.step_pre).to(torch.long).reshape(-1) ) + 1 + # Look up the lower bound slope from the pre-computed table. d_lower = torch.index_select(self.d_lower, 0, index).view(lower.shape) + # Indices of neurons with lower bound <=0, whose optimal slope to upper bound the function was pre-computed. index = torch.max( torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), (lower / -self.step_pre).to(torch.long).reshape(-1) ) + 1 - d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape) + d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape) - ns = self._start.name - - # bound with tangent lines can be optimized - if self.opt_stage == 'opt': + if self.opt_stage in ['opt', 'reuse']: if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. self._no_bound_parameters() + ns = self._start # Clipping is done here rather than after `opt.step()` call - because it depends on pre-activation bounds + # because it depends on pre-activation bounds self.alpha[ns].data[0, :] = torch.max(torch.min(self.alpha[ns][0, :], upper), lower) self.alpha[ns].data[1, :] = torch.max(torch.min(self.alpha[ns][1, :], upper), lower) self.alpha[ns].data[2, :] = torch.min(self.alpha[ns][2, :], d_lower) self.alpha[ns].data[3, :] = torch.max(self.alpha[ns][3, :], d_upper) + # shape [2, out_c, n, c, h, w]. tp_pos = self.alpha[ns][0] tp_neg = self.alpha[ns][1] tp_both_lower = self.alpha[ns][2] - tp_both_upper = self.alpha[ns][3] + tp_both_upper = self.alpha[ns][3] # No need to use tangent line, when the tangent point is at the left # side of the preactivation lower bound. Simply connect the two sides.
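A small numeric illustration (values assumed) of the comment above: the direct line is a sound lower bound for a crossing neuron only when its slope is below the tangent slope at the input lower bound, which is exactly the `k_direct < dfunc(lower)` test in the hunk below:

```python
import torch

def dtanh(x):
    return 1. / torch.cosh(x) ** 2

def use_direct_lower(l, u):
    # Slope of the chord between (l, tanh(l)) and (u, tanh(u)).
    k_direct = (torch.tanh(u) - torch.tanh(l)) / (u - l)
    return bool(k_direct < dtanh(l))

print(use_direct_lower(torch.tensor(-0.5), torch.tensor(3.0)))  # True: use the chord
print(use_direct_lower(torch.tensor(-3.0), torch.tensor(0.5)))  # False: use a tangent at d_lower
```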
- mask_direct = self.mask_both * ( k_direct < dfunc(lower) ) - self._add_linear(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) - self._add_linear(mask=self.mask_both - mask_direct, type='lower', - k=dfunc(tp_both_lower), x0=tp_both_lower, + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) + self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='lower', + k=dfunc(tp_both_lower), x0=tp_both_lower, y0=self.forward(tp_both_lower)) - mask_direct = self.mask_both * ( k_direct < dfunc(upper) ) - self._add_linear(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) - self._add_linear(mask=self.mask_both - mask_direct, type='upper', - k=dfunc(tp_both_upper), x0=tp_both_upper, + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) + self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='upper', + k=dfunc(tp_both_upper), x0=tp_both_upper, y0=self.forward(tp_both_upper)) - self._add_linear(mask=self.mask_neg, type='lower', + self.add_linear_relaxation( + mask=self.mask_neg, type='lower', k=dfunc(tp_neg), x0=tp_neg, y0=self.forward(tp_neg)) - self._add_linear(mask=self.mask_pos, type='upper', + self.add_linear_relaxation( + mask=self.mask_pos, type='upper', k=dfunc(tp_pos), x0=tp_pos, y0=self.forward(tp_pos)) else: + # Not optimized (vanilla CROWN bound). + # Use the middle point slope as the lower/upper bound. m = (lower + upper) / 2 y_m = func(m) k = dfunc(m) - # lower bound for negative - self._add_linear(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m) - # upper bound for positive - self._add_linear(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m) - + # Lower bound is the middle point slope for the case input upper bound <= 0. + # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m) + # Upper bound is the middle point slope for the case input lower bound >= 0. + # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m) + + # Now handle the case where input lower bound <= 0 and upper bound >= 0. + # A tangent line at d_lower is guaranteed to be a lower bound given the input upper bound. k = dfunc(d_lower) y0 = func(d_lower) if self.opt_stage == 'init': + # Initialize optimizable slope. + ns = self._start self.tp_both_lower_init[ns] = d_lower.detach() - mask_direct = self.mask_both * ( k_direct < dfunc(lower) ) - self._add_linear(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) - self._add_linear(mask=self.mask_both - mask_direct, type='lower', k=k, x0=d_lower, y0=y0) - + # Another possibility is to use the direct line as the lower bound, when this direct line does not intersect with f. + # This is only valid when the tangent at the input lower bound has a slope greater than that of the direct line. + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) + self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + # Otherwise we do not use the direct line; we use the tangent line at d_lower.
+ self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='lower', k=k, x0=d_lower, y0=y0) + + # Do the same for the upper bound side when input lower bound <=0 and upper bound >= 0. k = dfunc(d_upper) y0 = func(d_upper) if self.opt_stage == 'init': - self.tp_both_upper_init[ns] = d_upper.detach() + ns = self._start + self.tp_both_upper_init[ns] = d_upper.detach() self.tmp_lower = x.lower.detach() self.tmp_upper = x.upper.detach() - mask_direct = self.mask_both * ( k_direct < dfunc(upper) ) - self._add_linear(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) - self._add_linear(mask=self.mask_both - mask_direct, type='upper', k=k, x0=d_upper, y0=y0) + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) + self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='upper', k=k, x0=d_upper, y0=y0) def bound_relax(self, x): self.bound_relax_impl(x, torch.tanh, self.dtanh) + def dump_optimized_params(self): + return self.alpha + + def restore_optimized_params(self, alpha): + self.alpha = alpha + class BoundSigmoid(BoundTanh): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super(BoundTanh, self).__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super(BoundTanh, self).__init__(attr, inputs, output_index, options) self.precompute_relaxation('sigmoid', torch.sigmoid, self.dsigmoid) + # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions. + self.alpha_batch_dim = 3 - @Bound.save_io_shape def forward(self, x): return torch.sigmoid(x) @@ -1037,196 +1207,125 @@ def bound_relax(self, x): class BoundSoftplus(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super(BoundSoftplus, self).__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super(BoundSoftplus, self).__init__(attr, inputs, output_index, options) self.softplus = nn.Softplus() - @Bound.save_io_shape def forward(self, x): - return self.softplus(x) + return self.softplus(x) -class BoundExp(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.options = options.get('exp') - self.max_input = 0 +class BoundAbs(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) - @Bound.save_io_shape def forward(self, x): - if self.loss_fusion and self.options != 'no-max-input': - self.max_input = torch.max(x, dim=-1, keepdim=True)[0].detach() - return torch.exp(x - self.max_input) - return torch.exp(x) - - def interval_propagate(self, *v): - assert (len(v) == 1) - - if Interval.use_relative_bounds(*v): - assert not self.loss_fusion or self.options == 'no-max-input' - nominal = torch.exp(v[0].nominal) - return Interval( - None, None, - nominal, - nominal * (torch.exp(v[0].lower_offset) - 1), - nominal * (torch.exp(v[0].upper_offset) - 1) - ) - - # unary monotonous functions only - h_L, h_U = v[0] - if self.loss_fusion and self.options != 'no-max-input': - self.max_input = torch.max(h_U, dim=-1, keepdim=True)[0] - h_L, h_U = h_L - 
self.max_input, h_U - self.max_input - else: - self.max_input = 0 - return torch.exp(h_L), torch.exp(h_U) - - def bound_forward(self, dim_in, x): - m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) - - exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper) - - kl = exp_m - lw = x.lw * kl.unsqueeze(1) - lb = kl * (x.lb - m + 1) - - ku = (exp_u - exp_l) / (x.upper - x.lower + epsilon) - uw = x.uw * ku.unsqueeze(1) - ub = x.ub * ku - ku * x.lower + exp_l + return x.abs() - return LinearBound(lw, lb, uw, ub) - - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): - # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable). - if self.loss_fusion and last_lA is None and last_uA is not None and torch.min( - last_uA) >= 0 and x.from_input: - # Adding an extra bias term to the input. This is equivalent to adding a constant and subtract layer before exp. - # Note that we also need to adjust the bias term at the end. - if self.options == 'no-detach': - self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0] - elif self.options != 'no-max-input': - self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0].detach() + def bound_backward(self, last_lA, last_uA, x): + x_L = x.lower.clamp(max=0) + x_U = torch.max(x.upper.clamp(min=0), x_L + 1e-8) + mask_neg = x_U <= 0 + mask_pos = x_L >= 0 + y_L = x_L.abs() + y_U = x_U.abs() + upper_k = (y_U - y_L) / (x_U - x_L) + upper_b = y_L - upper_k * x_L + lower_k = (mask_neg * (-1.0) + mask_pos * 1.0) + lower_b = (mask_neg + mask_pos) * ( y_L - lower_k * x_L ) + if last_uA is not None: + # Special case if we only want the upper bound with non-negative coefficients + if last_uA.min() >= 0: + uA = last_uA * upper_k + ubias = self.get_bias(last_uA, upper_b) else: - self.max_input = 0 - adjusted_lower = x.lower - self.max_input - adjusted_upper = x.upper - self.max_input - # relaxation for upper bound only (used in loss fusion) - exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper) - k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower + epsilon) - if k.requires_grad: - k = k.clamp(min=1e-6) - uA = last_uA * k.unsqueeze(0) - ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0) - - if ubias.ndim > 2: - ubias = torch.sum(ubias, dim=tuple(range(2, ubias.ndim))) - # Also adjust the missing ubias term. 
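For the new `BoundAbs.bound_backward` coefficients shown above, the relaxation is simple because `|x|` is convex: the chord between the endpoints is always a sound upper bound, and a slope of -1, +1, or 0 (depending on the sign of the input range) is a sound lower bound. A toy standalone sketch of those coefficients, outside the class:

```python
import torch

x_L = torch.tensor([-2.0, -3.0, 0.5]).clamp(max=0)
x_U = torch.max(torch.tensor([1.0, -1.0, 2.0]).clamp(min=0), x_L + 1e-8)
mask_neg = x_U <= 0  # entirely non-positive: |x| = -x exactly
mask_pos = x_L >= 0  # entirely non-negative: |x| = x exactly
y_L, y_U = x_L.abs(), x_U.abs()

# Upper bound: the chord between the endpoints (sound since |x| is convex).
upper_k = (y_U - y_L) / (x_U - x_L)
upper_b = y_L - upper_k * x_L
# Lower bound: slope -1 or +1 on stable neurons, slope 0 on unstable ones.
lower_k = mask_neg * (-1.0) + mask_pos * 1.0
lower_b = (mask_neg | mask_pos) * (y_L - lower_k * x_L)
print(upper_k, lower_k)
```

On stable neurons both bounds are exact, so the relaxation only loses tightness where the input interval crosses zero.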
- if uA.ndim > self.max_input.ndim: - A = torch.sum(uA, dim=tuple(range(self.max_input.ndim, uA.ndim))) + last_uA_pos = last_uA.clamp(min=0) + last_uA_neg = last_uA.clamp(max=0) + uA = last_uA_pos * upper_k + last_uA_neg * lower_k + ubias = (self.get_bias(last_uA_pos, upper_b) + + self.get_bias(last_uA_neg, lower_b)) + else: + uA, ubias = None, 0 + if last_lA is not None: + if last_lA.max() <= 0: + lA = last_lA * upper_k + lbias = self.get_bias(last_lA, upper_b) else: - A = uA - - # These should hold true in loss fusion - assert self.batch_dim == 0 - assert A.shape[0] == 1 - - batch_size = A.shape[1] - ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0) - return [(None, uA)], 0, ubias + last_lA_pos = last_lA.clamp(min=0) + last_lA_neg = last_lA.clamp(max=0) + lA = last_lA_pos * lower_k + last_lA_neg * upper_k + lbias = (self.get_bias(last_lA_pos, lower_b) + + self.get_bias(last_lA_neg, upper_b)) else: - return super().bound_backward(last_lA, last_uA, x) + lA, lbias = None, 0 + return [(lA, uA)], lbias, ubias - def bound_relax(self, x): - min_val = -1e9 - l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val) - m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) - exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper) - k = exp_m - self._add_linear(mask=None, type='lower', k=k, x0=m, y0=exp_m) - min_val = -1e9 # to avoid (-inf)-(-inf) when both input.lower and input.upper are -inf - epsilon = 1e-20 - close = (u - l < epsilon).int() - k = close * exp_u + (1 - close) * (exp_u - exp_l) / (u - l + epsilon) - self._add_linear(mask=None, type='upper', k=k, x0=l, y0=exp_l) - - -class BoundLog(BoundActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True - - @Bound.save_io_shape - def forward(self, x): - # NOTE adhoc implementation for loss fusion - if self.loss_fusion: - return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1) - return torch.log(x.clamp(min=epsilon)) + def interval_propagate(self, *v): + h_L, h_U = v[0][0], v[0][1] + lower = ((h_U < 0) * h_U.abs() + (h_L > 0) * h_L.abs()) + upper = torch.max(h_L.abs(), h_U.abs()) + return lower, upper - def bound_relax(self, x): - rl, ru = self.forward(x.lower), self.forward(x.upper) - ku = (ru - rl) / (x.upper - x.lower + epsilon) - self._add_linear(mask=None, type='lower', k=ku, x0=x.lower, y0=rl) - m = (x.lower + x.upper) / 2 - k = torch.reciprocal(m) - rm = self.forward(m) - self._add_linear(mask=None, type='upper', k=k, x0=m, y0=rm) - def interval_propagate(self, *v): - # NOTE adhoc implementation for loss fusion - if self.loss_fusion: - par = self.inputs[0].inputs[0].inputs[0] - if Interval.use_relative_bounds(*v): - lower = torch.logsumexp(par.interval.nominal + par.interval.lower_offset, dim=-1) - upper = torch.logsumexp(par.interval.nominal + par.interval.upper_offset, dim=-1) - return Interval.make_interval(lower, upper, nominal=self.forward_value, use_relative=True) - else: - lower = torch.logsumexp(par.lower, dim=-1) - upper = torch.logsumexp(par.upper, dim=-1) - return lower, upper - return super().interval_propagate(*v) +class BoundATenHeaviside(BoundOptimizableActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.alpha_batch_dim = 2 - def bound_backward(self, last_lA, last_uA, x, 
start_node=None, start_shape=None): - A, lbias, ubias = super().bound_backward(last_lA, last_uA, x) - # NOTE adhoc implementation for loss fusion - if self.loss_fusion: - assert A[0][0] is None - exp_module = self.inputs[0].inputs[0] - ubias = ubias + self.get_bias(A[0][1], exp_module.max_input.squeeze(-1)) - return A, lbias, ubias + def forward(self, *x): + self.input_shape = x[0].shape + # x[0]: input; x[1]: value when x == 0 + return torch.heaviside(x[0], x[1]) + def init_opt_parameters(self, start_nodes): + l = self.inputs[0].lower + for ns, size_s, _ in start_nodes: + self.alpha[ns] = torch.zeros_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim).requires_grad_(True) -class BoundPow(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True + def clip_alpha_(self): + for v in self.alpha.values(): + v.data = torch.clamp(v.data, 0., 1.) - @Bound.save_io_shape - def forward(self, x, y): - return torch.pow(x, y) + def bound_backward(self, last_lA, last_uA, *x, start_node=None, start_shape=None): + x = x[0] + if x is not None: + lb_r = x.lower + ub_r = x.upper + else: + lb_r = self.lower + ub_r = self.upper - def interval_propagate(self, *v): - assert not self.is_input_perturbed(1) - - if Interval.use_relative_bounds(*v): - exp = v[1].nominal - assert exp == int(exp) - exp = int(exp) - h_L = v[0].nominal + v[0].lower_offset - h_U = v[0].nominal + v[0].upper_offset - lower, upper = torch.pow(h_L, exp), torch.pow(h_U, exp) - if exp % 2 == 0: - lower, upper = torch.min(lower, upper), torch.max(lower, upper) - mask = 1 - ((h_L < 0) * (h_U > 0)).float() - lower = lower * mask - return Interval.make_interval(lower, upper, nominal=self.forward_value, use_relative=True) - - exp = v[1][0] - assert exp == int(exp) - exp = int(exp) - pl, pu = torch.pow(v[0][0], exp), torch.pow(v[0][1], exp) - if exp % 2 == 1: - return pl, pu + if self.opt_stage not in ['opt', 'reuse']: + # zero slope: + upper_d = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype) + lower_d = torch.zeros_like(ub_r, device=ub_r.device, dtype=ub_r.dtype) else: - pl, pu = torch.min(pl, pu), torch.max(pl, pu) - mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).float() - return pl * mask, pu + upper_d = self.alpha[start_node.name][0].clamp(0, 1) * (1. / (-lb_r).clamp(min=1e-3)) + lower_d = self.alpha[start_node.name][1].clamp(0, 1) * (1. / (ub_r.clamp(min=1e-3))) + + upper_b = torch.ones_like(lb_r, device=lb_r.device, dtype=lb_r.dtype) + lower_b = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype) + # For stable neurons, set fixed slope and bias. + ub_mask = (ub_r <= 0).to(dtype=ub_r.dtype) + lb_mask = (lb_r >= 0).to(dtype=lb_r.dtype) + upper_b = upper_b - upper_b * ub_mask + lower_b = lower_b * (1. 
- lb_mask) + lb_mask + upper_d = upper_d - upper_d * ub_mask - upper_d * lb_mask + lower_d = lower_d - lower_d * lb_mask - lower_d * ub_mask + upper_d = upper_d.unsqueeze(0) + lower_d = lower_d.unsqueeze(0) + # Choose upper or lower bounds based on the sign of last_A + uA = lA = None + ubias = lbias = 0 + if last_uA is not None: + neg_uA = last_uA.clamp(max=0) + pos_uA = last_uA.clamp(min=0) + uA = upper_d * pos_uA + lower_d * neg_uA + ubias = (pos_uA * upper_b + neg_uA * lower_b).flatten(2).sum(-1) + if last_lA is not None: + neg_lA = last_lA.clamp(max=0) + pos_lA = last_lA.clamp(min=0) + lA = upper_d * neg_lA + lower_d * pos_lA + lbias = (pos_lA * lower_b + neg_lA * upper_b).flatten(2).sum(-1) + + return [(lA, uA), (None, None)], lbias, ubias diff --git a/auto_LiRPA/operators/base.py b/auto_LiRPA/operators/base.py index 2062116..b16d813 100644 --- a/auto_LiRPA/operators/base.py +++ b/auto_LiRPA/operators/base.py @@ -3,22 +3,27 @@ import os import time import math +import warnings import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor import numpy as np from itertools import chain from numpy.lib.arraysetops import isin from collections import OrderedDict -from auto_LiRPA.perturbations import * -from auto_LiRPA.utils import * +from ..perturbations import * +from ..utils import * +from ..patches import * +from ..linear_bound import LinearBound torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_mode(False) epsilon = 1e-12 + def not_implemented_op(node, func): message = ("Function `{}` of `{}` is not supported yet." " Please help to open an issue at https://github.com/KaidiXu/auto_LiRPA" @@ -26,26 +31,21 @@ def not_implemented_op(node, func): " or auto_LiRPA/operators by yourself.".format(func, node)) raise NotImplementedError(message) -"""Interval object. Used for interval bound propagation.""" + class Interval(tuple): + """Interval object for interval bound propagation.""" + # Subclassing tuple object so that all previous code can be reused. - def __new__(self, lb=None, ub=None, nominal=None, lower_offset=None, upper_offset=None, ptb=None): + def __new__(self, lb=None, ub=None, ptb=None): return tuple.__new__(Interval, (lb, ub)) - def __init__(self, lb, ub, nominal=None, lower_offset=None, upper_offset=None, ptb=None): - self.nominal = nominal - self.lower_offset = lower_offset - self.upper_offset = upper_offset - + def __init__(self, lb, ub, ptb=None): if ptb is None: self.ptb = None - # If relative bounds are not used, `self.ptb == None` means that this interval + # `self.ptb == None` means that this interval # is not perturbed and it shall be treated as a constant and lb = ub. - # But if relative bounds are used, every node in IBP is supposed to have an `Interval` object - # even if this node is perturbed. - if nominal is None: - # To avoid mistakes, in this case the caller must make sure lb and ub are the same object. - assert lb is ub + # To avoid mistakes, in this case the caller must make sure lb and ub are the same object. + assert lb is ub else: if not isinstance(ptb, Perturbation): raise ValueError("ptb must be a Perturbation object or None. 
Got type {}".format(type(ptb))) @@ -58,34 +58,17 @@ def __str__(self): def __repr__(self): return "Interval(lb={}, ub={}, ptb={})".format(self[0], self[1], self.ptb) - @property - def lower(self): - return self.nominal + self.lower_offset - - @property - def upper(self): - return self.nominal + self.upper_offset - - """Checking if the other interval is tuple, keep the perturbation.""" - @staticmethod - def make_interval(lb, ub, other=None, nominal=None, use_relative=False): + def make_interval(lb, ub, other=None): + """Checking if the other interval is tuple, keep the perturbation.""" if isinstance(other, Interval): return Interval(lb, ub, ptb=other.ptb) else: - if use_relative: - if nominal is None: - return Interval( - None, None, (lb + ub) / 2, (lb - ub) / 2, (ub - lb) / 2) - else: - return Interval(None, None, nominal, lb - nominal, ub - nominal) - else: - return (lb, ub) - - """Given a tuple or Interval object, returns the norm and eps.""" + return (lb, ub) @staticmethod def get_perturbation(interval): + """Given a tuple or Interval object, returns the norm and eps.""" if isinstance(interval, Interval) and interval.ptb is not None: if isinstance(interval.ptb, PerturbationLpNorm): return interval.ptb.norm, interval.ptb.eps @@ -93,33 +76,21 @@ def get_perturbation(interval): return np.inf, 1.0 elif isinstance(interval.ptb, PerturbationL0Norm): return 0, interval.ptb.eps, interval.ptb.ratio - # elif interval.ptb is None: - # raise RuntimeError("get_perturbation() encountered an interval that is not perturbed.") else: raise RuntimeError("get_perturbation() does not know how to handle {}".format(type(interval.ptb))) else: # Tuple object. Assuming L infinity norm lower and upper bounds. return np.inf, np.nan - """Checking if a Interval or tuple object has perturbation enabled.""" @staticmethod def is_perturbed(interval): + """Checking if a Interval or tuple object has perturbation enabled.""" if isinstance(interval, Interval) and interval.ptb is None: return False else: return True - @staticmethod - def use_relative_bounds(*intervals): - using = True - for interval in intervals: - using = using and ( - isinstance(interval, Interval) and - interval.nominal is not None and - interval.lower_offset is not None and interval.upper_offset is not None) - return using - class Bound(nn.Module): r""" @@ -127,12 +98,6 @@ class Bound(nn.Module): at `auto_LiRPA/operators`. Args: - input_name (list): The name of input nodes. - - name (str): The name of this node. - - ori_name (str): Name in the original model. - attr (dict): Attributes of the operator. inputs (list): A list of input nodes. @@ -141,18 +106,22 @@ class Bound(nn.Module): options (dict): Bound options. - device (str or torch.device): Device of the bounded module. - - Be sure to run `super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)` + Be sure to run `super().__init__(attr, inputs, output_index, options, device)` first in the `__init__` function. 
""" - def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index=0, options={}, device=None): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__() + attr = {} if attr is None else attr + inputs = [] if inputs is None else inputs + options = {} if options is None else options + self.name = None self.output_name = [] - self.input_name, self.name, self.ori_name, self.attr, self.inputs, self.output_index, self.options, self.device = \ - input_name, name, ori_name, attr, inputs, output_index, options, device + self.device = attr.get('device') + self.attr, self.inputs, self.output_index, self.options = \ + attr, inputs, output_index, options self.forward_value = None + self.output_shape = None self.from_input = False self.bounded = False self.IBP_rets = None @@ -164,7 +133,7 @@ def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index= self.loss_fusion = False self.options = options # Use `default_interval_propagate` - self.use_default_ibp = False + self.use_default_ibp = False # If set to true, the backward bound output of this node is 0. self.zero_backward_coeffs_l = False self.zero_backward_coeffs_u = False @@ -172,15 +141,24 @@ def __init__(self, input_name, name, ori_name, attr={}, inputs=[], output_index= self.zero_lA_mtx = False self.zero_uA_mtx = False - """Check if the i-th input is with perturbation or not.""" + self.patches_start = False + + def __repr__(self): + return f'{self.__class__.__name__}(name="{self.name}")' + def is_input_perturbed(self, i=0): - return self.inputs[i].perturbed + r"""Check if the i-th input is with perturbation or not.""" + return i < len(self.inputs) and self.inputs[i].perturbed + + def clear(self): + """ Clear attributes when there is a new input to the network""" + pass def forward(self, *x): r""" - Function for standard/clean forward. + Function for standard/clean forward. - Args: + Args: x: A list of input values. The length of the list is equal to the number of input nodes. Returns: @@ -192,37 +170,30 @@ def interval_propagate(self, *v): r""" Function for interval bound propagation (IBP) computation. - There is a default function `self.default_interval_propagate(*v)` in the base class, + There is a default function `self.default_interval_propagate(*v)` in the base class, which can be used if the operator is *monotonic*. To use it, set `self.use_default_ibp = True` in the `__init__` function, and the implementation of this function can be skipped. - Args: - v: A list of the interval bound of input nodes. + Args: + v: A list of the interval bound of input nodes. Generally, for each element `v[i]`, `v[i][0]` is the lower interval bound, and `v[i][1]` is the upper interval bound. Returns: bound: The interval bound of this node, in a same format as v[i]. 
- """ + """ if self.use_default_ibp: return self.default_interval_propagate(*v) else: return not_implemented_op(self, 'interval_propagate') - - """For unary monotonous functions or functions for altering shapes only but not values""" + def default_interval_propagate(self, *v): + """For unary monotonous functions or functions for altering shapes only but not values""" if len(v) == 0: return Interval.make_interval(self.forward(), self.forward()) elif len(v) == 1: - if Interval.use_relative_bounds(v[0]): - return Interval( - None, None, - self.forward(v[0].nominal), - self.forward(v[0].lower_offset), - self.forward(v[0].upper_offset) - ) - else: - return Interval.make_interval(self.forward(v[0][0]), self.forward(v[0][1]), v[0]) + return Interval.make_interval( + self.forward(v[0][0]), self.forward(v[0][1]), v[0]) else: raise NotImplementedError('default_interval_propagate only supports no more than 1 input node') @@ -230,34 +201,41 @@ def default_interval_propagate(self, *v): def bound_forward(self, dim_in, *x): r""" Function for forward mode bound propagation. Forward mode LiRPA computs a `LinearBound` - instance representing the linear bound for each involved node. Major attributes of `LinearBound` include + instance representing the linear bound for each involved node. + Major attributes of `LinearBound` include `lw`, `uw`, `lb`, `ub`, `lower`, and `upper`. - `lw` and `uw` are coefficients of linear bounds w.r.t. model input. - Their shape is `(batch_size, dim_in, *standard_shape)`, where `dim_in` is the total dimension - of perturbed input nodes of the model, and `standard_shape` is the shape of the standard/clean output. - `lb` and `ub` are bias terms of linear bounds, and their shape is equal to the shape of standard/clean output. - `lower` and `upper` are concretized lower and upper bounds that will be computed later in BoundedModule. + `lw` and `uw` are coefficients of linear bounds w.r.t. model input. + Their shape is `(batch_size, dim_in, *standard_shape)`, + where `dim_in` is the total dimension of perturbed input nodes of the model, + and `standard_shape` is the shape of the standard/clean output. + `lb` and `ub` are bias terms of linear bounds, and their shape is equal + to the shape of standard/clean output. + `lower` and `upper` are concretized lower and upper bounds that will be + computed later in BoundedModule. - Args: + Args: dim_in (int): Total dimension of perturbed input nodes of the model. - + x: A list of the linear bound of input nodes. Each element in x is a `LinearBound` instance. Returns: bound (LinearBound): The linear bound of this node. - """ + """ return not_implemented_op(self, 'bound_forward') + def bound_dynamic_forward(self, *x, max_dim=None, offset=0): + raise NotImplementedError(f'bound_dynamic_forward is not implemented for {self}.') + def bound_backward(self, last_lA, last_uA, *x): r""" Function for backward mode bound propagation. - Args: + Args: last_lA (Tensor): `A` matrix for lower bound computation propagated to this node. It can be `None` if lower bound is not needed. - + last_uA (Tensor): `A` matrix for upper bound computation propagated to this node. It can be `None` if upper bound is not needed. - + x: A list of input nodes, with x[i].lower and x[i].upper that can be used as pre-activation bounds. Returns: @@ -265,8 +243,8 @@ def bound_backward(self, last_lA, last_uA, *x): lbias (Tensor): The bias term for lower bound computation, introduced by the linear relaxation of this node. . 
- ubias (Tensor): The bias term for upper bound computation, introduced by the linear relaxation of this node. - """ + ubias (Tensor): The bias term for upper bound computation, introduced by the linear relaxation of this node. + """ return not_implemented_op(self, 'bound_backward') def infer_batch_dim(self, batch_size, *x): @@ -278,8 +256,8 @@ def infer_batch_dim(self, batch_size, *x): def broadcast_backward(self, A, x): shape = x.output_shape batch_dim = max(self.batch_dim, 0) - - if isinstance(A, torch.Tensor): + + if isinstance(A, Tensor): if x.batch_dim == -1: # final shape of input shape = torch.Size([A.shape[batch_dim + 1]] + list(shape)) @@ -298,6 +276,7 @@ def broadcast_backward(self, A, x): dims = [] for i in range(len(shape)): # Skip the batch dimension. + # FIXME (05/11/2022): the following condition is not always correct. We should not rely on checking dimension is "1" or not. if shape[i] == 1 and A.shape[i + 1] != 1 and i != batch_dim: dims.append(i + 1) if dims: @@ -307,30 +286,6 @@ def broadcast_backward(self, A, x): pass return A - @staticmethod - def broadcast_forward(dim_in, x, shape_res): - lw, lb, uw, ub = x.lw, x.lb, x.uw, x.ub - shape_x, shape_res = list(x.lb.shape), list(shape_res) - if lw is None: - lw = uw = torch.zeros(dim_in, *shape_x, device=lb.device) - has_batch_size = False - else: - has_batch_size = True - while len(shape_x) < len(shape_res): - if not has_batch_size: - lw, uw = lw.unsqueeze(0), uw.unsqueeze(0) - lb, ub = lb.unsqueeze(0), ub.unsqueeze(0) - shape_x = [1] + shape_x - has_batch_size = True - else: - lw, uw = lw.unsqueeze(2), uw.unsqueeze(2) - lb, ub = lb.unsqueeze(1), ub.unsqueeze(1) - shape_x = [shape_x[0], 1] + shape_x[1:] - lb, ub = lb.expand(*shape_res), ub.expand(*shape_res) - lw = lw.expand(shape_res[0], lw.size(1), *shape_res[1:]) - uw = uw.expand(shape_res[0], uw.size(1), *shape_res[1:]) - return lw, lb, uw, ub - def get_bias(self, A, bias): if A is None: return 0 @@ -340,8 +295,7 @@ def get_bias(self, A, bias): if torch.isinf(bias).any(): warnings.warn('There is an inf value in the bias of LiRPA bounds.') - if isinstance(A, torch.Tensor): - output_dim = A.shape[0] + if isinstance(A, Tensor): if self.batch_dim != -1: bias_new = torch.einsum('sb...,b...->sb', A, bias) else: @@ -354,14 +308,19 @@ def get_bias(self, A, bias): else: # FIXME (09/17): handle the case for pieces.unstable_idx. return bias_new + elif isinstance(A, eyeC): + batch_size = A.shape[1] + if self.batch_dim != -1: + return bias.reshape(batch_size, -1).t() + else: + return bias.reshape(-1).unsqueeze(-1).repeat(1, batch_size) elif type(A) == Patches: # the shape of A.patches is [batch, L, out_c, in_c, K, K] if self.batch_dim != -1: # Input A patches has shape (spec, batch, out_h, out_w, in_c, H, W) or (unstable_size, batch, in_c, H, W). patches = A.patches - # Here the size of bias is [batch_size, out_h, out_w, in_c, H, W] - bias = inplace_unfold(bias, kernel_size=A.patches.shape[-2:], stride=A.stride, padding=A.padding) + bias = inplace_unfold(bias, kernel_size=A.patches.shape[-2:], stride=A.stride, padding=A.padding, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding) if A.unstable_idx is not None: # Sparse bias has shape [unstable_size, batch_size, in_c, H, W]. No need to select over the out_c dimension. bias = bias[:, A.unstable_idx[1], A.unstable_idx[2]] @@ -379,42 +338,24 @@ def get_bias(self, A, bias): bias_new = torch.sum(patches, dim=(-1, -2, -3)) * bias.to(self.device) # Return shape is (spec, batch, out_h, out_w) or (unstable_size, batch). 
return bias_new - return bias_new else: return NotImplementedError() - @staticmethod - @torch.jit.script - def clamp_mutiply(A, pos, neg): - Apos = A.clamp(min=0) - Aneg = A.clamp(max=0) - return pos.contiguous() * Apos + neg.contiguous() * Aneg, Apos, Aneg - - @staticmethod - @torch.jit.script - def clamp_mutiply_non_contiguous(A, pos, neg): - Apos = A.clamp(min=0) - Aneg = A.clamp(max=0) - return pos * Apos + neg * Aneg, Apos, Aneg - - """save input and output shapes uniformly by the decorator""" - @staticmethod - def save_io_shape(func): - def wrapper(self, *args, **kwargs): - if len(args) > 0: - self.input_shape = args[0].shape # x should always be the first input - - output = func(self, *args, **kwargs) - - if isinstance(output, torch.Tensor): - self.output_shape = output.shape - return output - - return wrapper + def make_axis_non_negative(self, axis, shape='input'): + if shape == 'input': + shape = self.input_shape + elif shape == 'output': + shape = self.output_shape + else: + assert isinstance(shape, torch.Size) + if axis < 0: + return axis + len(shape) + else: + return axis - """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it.""" def non_deter_wrapper(self, op, *args, **kwargs): + """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it.""" if self.options.get('deterministic', False): torch.use_deterministic_algorithms(False) ret = op(*args, **kwargs) diff --git a/auto_LiRPA/operators/bivariate.py b/auto_LiRPA/operators/bivariate.py index 8ec0936..ffda86e 100644 --- a/auto_LiRPA/operators/bivariate.py +++ b/auto_LiRPA/operators/bivariate.py @@ -1,14 +1,33 @@ """ Bivariate operators""" from .base import * -from .activation import BoundSqrt, BoundReciprocal +from .nonlinear import BoundSqrt, BoundReciprocal +from .clampmult import multiply_by_A_signs +from ..utils import * +from .solver_utils import grb +from .constant import BoundConstant +from .leaf import BoundParams, BoundBuffers class BoundMul(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.is_constant_op = False + for inp in inputs: + if BoundMul._check_const_input(inp): + # If any of the two inputs are constant, we do not need input bounds. + # FIXME (05/11/2022): this is just a temporary workaround. We need better way to determine whether we need input bounds, not just for BoundConstant. + self.is_constant_op = True + if self.is_constant_op: + # One input is constant; no bounds required. + self.requires_input_bounds = [] + else: + # Both inputs are perturbed. Need relaxation. + self.requires_input_bounds = [0, 1] + + @staticmethod + def _check_const_input(inp): + return isinstance(inp, (BoundConstant, BoundBuffers)) or (isinstance(inp, BoundParams) and inp.perturbation is None) - @Bound.save_io_shape def forward(self, x, y): self.x_shape = x.shape self.y_shape = y.shape @@ -23,7 +42,6 @@ def get_bound_mul(x_l, x_u, y_l, y_u): alpha_u = y_u beta_u = x_l gamma_u = -alpha_u * beta_u - return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u # Special case when input is x * x. 
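The `get_bound_mul` coefficients shown above are the standard bilinear (McCormick-style) relaxation planes for a product `x * y` over a box. A quick numeric check that these planes really bound the product everywhere on the box (a standalone sketch; values are arbitrary):

```python
import torch

# For x in [x_l, x_u] and y in [y_l, y_u]:
#   x*y >= y_l*x + x_l*y - x_l*y_l   (alpha_l, beta_l, gamma_l above)
#   x*y <= y_u*x + x_l*y - x_l*y_u   (alpha_u, beta_u, gamma_u above)
x_l, x_u, y_l, y_u = -1.0, 2.0, -3.0, 1.0
X, Y = torch.meshgrid(torch.linspace(x_l, x_u, 50),
                      torch.linspace(y_l, y_u, 50), indexing='ij')
lower_plane = y_l * X + x_l * Y - x_l * y_l
upper_plane = y_u * X + x_l * Y - x_l * y_u
# Soundness follows from (x - x_l)*(y - y_l) >= 0 and (x - x_l)*(y_u - y) >= 0.
assert (X * Y >= lower_plane - 1e-5).all()
assert (X * Y <= upper_plane + 1e-5).all()
print('bilinear relaxation planes hold on the whole box')
```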
@@ -58,17 +76,62 @@ def _relax(x, y):
        x_l, x_u = x.lower, x.upper
        y_l, y_u = y.lower, y.upper

-        # broadcast
-        for k in [1, -1]:
-            x_l = x_l + k * y_l
-            x_u = x_u + k * y_u
-        for k in [1, -1]:
-            y_l = y_l + k * x_l
-            y_u = y_u + k * x_u
+        # Broadcast the bounds of x and y to a common shape.
+        x_l = x_l + torch.zeros_like(y_l)
+        x_u = x_u + torch.zeros_like(y_u)
+        y_l = y_l + torch.zeros_like(x_l)
+        y_u = y_u + torch.zeros_like(x_u)

        return BoundMul.get_bound_mul(x_l, x_u, y_l, y_u)

+    @staticmethod
+    def _multiply_by_const(x, const):
+        if isinstance(x, torch.Tensor):
+            return x * const
+        elif isinstance(x, Patches):
+            # Multiply patches by a constant. The constant must be a tensor in NCHW format.
+            assert isinstance(const, torch.Tensor) and const.ndim == 4
+            if const.size(0) == x.patches.size(1) and const.size(1) == x.patches.size(-3) and const.size(2) == const.size(3) == 1:
+                # The case where we can do channel-wise broadcast multiplication.
+                # Shape of const: (batch, in_c, 1, 1)
+                # Shape of patches when unstable_idx is None: (spec, batch, in_c, patch_h, patch_w)
+                # Shape of patches when unstable_idx is not None: (out_c, batch, out_h, out_w, in_c, patch_h, patch_w)
+                const_reshaped = const
+            else:
+                assert x.unstable_idx is None and (x.padding == 0 or x.padding == [0,0,0,0]) and x.stride == 1 and x.patches.size(-1) == x.patches.size(-2) == 1
+                # The assumed dimension is (out_c, N, out_h, out_w, in_c, 1, 1) with padding = 0 and stride = 1.
+                # In this special case we can directly multiply.
+                # After the reshape, const has shape (1, N, H, W, C, 1, 1).
+                const_reshaped = const.permute(0, 2, 3, 1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            return x.create_similar(x.patches * const_reshaped)
+        else:
+            raise ValueError(f'Unsupported x type {type(x)}')
+
+    @staticmethod
+    def bound_backward_constant(last_lA, last_uA, x, y, op=None):
+        op = BoundMul._multiply_by_const if op is None else op
+        # Handle the case of multiplication by a constant.
+        factor = None
+        if not BoundMul._check_const_input(x):
+            factor = y.value
+        if not BoundMul._check_const_input(y):
+            factor = x.value
+        # No need to compute an A matrix for the constant input.
+        lAx = None if BoundMul._check_const_input(x) or last_lA is None else op(last_lA, factor)
+        lAy = None if BoundMul._check_const_input(y) or last_lA is None else op(last_lA, factor)
+        uAx = None if BoundMul._check_const_input(x) or last_uA is None else op(last_uA, factor)
+        uAy = None if BoundMul._check_const_input(y) or last_uA is None else op(last_uA, factor)
+
+        return [(lAx, uAx), (lAy, uAy)], 0., 0.
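The reason `bound_backward_constant` above can return zero bias terms: multiplication by a constant is exactly linear, so the incoming `A` matrix is simply scaled elementwise and no relaxation gap is introduced. A minimal check of this identity (shapes and values are toy assumptions):

```python
import torch

c = torch.tensor([2.0, -0.5, 3.0])   # constant operand, per element
last_A = torch.randn(4, 1, 3)        # (spec, batch, neurons) coefficients
new_A = last_A * c                   # propagate backward through z = c * x

x = torch.randn(1, 3)
z = c * x
# Both contractions agree exactly, so no bias term is needed.
assert torch.allclose(torch.einsum('sbi,bi->sb', last_A, z),
                      torch.einsum('sbi,bi->sb', new_A, x))
```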
+ + def bound_backward(self, last_lA, last_uA, x, y): + if self.is_constant_op: + return self.bound_backward_constant(last_lA, last_uA, x, y) + else: + return self.bound_backward_both_perturbed(last_lA, last_uA, x, y) + + def bound_backward_both_perturbed(self, last_lA, last_uA, x, y): alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = BoundMul._relax(x, y) alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0) @@ -79,14 +142,78 @@ def _bound_oneside(last_A, alpha_neg, beta_neg, gamma_neg): if last_A is None: return None, None, 0 - last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0) - A_x = last_A_pos * alpha_pos + last_A_neg * alpha_neg - A_y = last_A_pos * beta_pos + last_A_neg * beta_neg - last_A = last_A.reshape(last_A.shape[0], last_A.shape[1], -1) - A_x = self.broadcast_backward(A_x, x) - A_y = self.broadcast_backward(A_y, y) - bias = self.get_bias(last_A_pos, gamma_pos) + \ - self.get_bias(last_A_neg, gamma_neg) + + if type(last_A) == Patches: + # In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return. + def _maybe_unfold(d_tensor, last_A): + if d_tensor is None: + return None + + d_shape = d_tensor.size() + # Reshape to 4-D tensor to unfold. + d_tensor = d_tensor.view(-1, *d_shape[-3:]) + # unfold the slope matrix as patches. Patch shape is [spec * batch, out_h, out_w, in_c, H, W). + d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) + # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c. + d_unfolded_r = d_unfolded.view(*last_A.shape[:3], *d_unfolded.shape[1:]) + if last_A.unstable_idx is not None: + if d_unfolded_r.size(0) == 1: + # Broadcast the spec shape, so only need to select the reset dimensions. + # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W). + d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5) + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + # output shape: (unstable_size, batch, in_c, H, W). + else: + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] + # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W). + # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W). + return d_unfolded_r + # if last_A is not an identity matrix + assert last_A.identity == 0 + if last_A.identity == 0: + # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension. 
+ # for patches mode, we need to unfold the alpha_pos/neg and beta_pos/neg + + alpha_pos = _maybe_unfold(alpha_pos, last_A) + alpha_neg = _maybe_unfold(alpha_neg, last_A) + beta_pos = _maybe_unfold(beta_pos, last_A) + beta_neg = _maybe_unfold(beta_neg, last_A) + + gamma_pos = _maybe_unfold(gamma_pos, last_A) + gamma_neg = _maybe_unfold(gamma_neg, last_A) + + patches = last_A.patches + patches_shape = patches.shape + A_x, bias = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), alpha_pos, alpha_neg, gamma_pos, gamma_neg, patches_mode=True) + A_y, _ = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), beta_pos, beta_neg, None, None, patches_mode=True) + A_x = A_x.view(patches_shape) + A_y = A_y.view(patches_shape) + + # broadcast_backward + x_dims = [] + y_dims = [] + + if A_x.shape[A_x.ndim-4] != x.output_shape[len(x.output_shape)-4]: + x_dims.append(A_x.ndim-4) + + if A_y.shape[A_y.ndim-4] != y.output_shape[len(y.output_shape)-4]: + y_dims.append(A_y.ndim-4) + + if len(x_dims) > 0: + A_x = A_x.sum(tuple(x_dims), keepdim=True) + if len(y_dims) > 0: + A_y = A_y.sum(tuple(y_dims), keepdim=True) + + A_x = Patches(A_x, last_A.stride, last_A.padding, A_x.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) + A_y = Patches(A_y, last_A.stride, last_A.padding, A_y.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) + if type(last_A) == Tensor: + last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0) + A_x = last_A_pos * alpha_pos + last_A_neg * alpha_neg + A_y = last_A_pos * beta_pos + last_A_neg * beta_neg + A_x = self.broadcast_backward(A_x, x) + A_y = self.broadcast_backward(A_y, y) + bias = self.get_bias(last_A_pos, gamma_pos) + \ + self.get_bias(last_A_neg, gamma_neg) return A_x, A_y, bias lA_x, lA_y, lbias = _bound_oneside( @@ -96,8 +223,13 @@ def _bound_oneside(last_A, return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias + def bound_forward(self, dim_in, x, y): + if self.is_constant_op: + raise NotImplementedError + return self.bound_forward_both_perturbed(dim_in, x, y) + @staticmethod - def bound_forward(dim_in, x, y): + def bound_forward_both_perturbed(dim_in, x, y): x_lw, x_lb, x_uw, x_ub = x.lw, x.lb, x.uw, x.ub y_lw, y_lb, y_uw, y_ub = y.lw, y.lb, y.uw, y.ub @@ -120,7 +252,28 @@ def bound_forward(dim_in, x, y): return LinearBound(lw, lb, uw, ub) @staticmethod - def interval_propagate(*v): + def interval_propagate_constant(*v, op=lambda x, const: x * const): + x, y = v[0], v[1] + x_is_const = x[0] is x[1] # FIXME: using better way to represent constant perturbation. + y_is_const = y[0] is y[1] # We should not check the distance between x[0] and x[1]. It's slow! + assert x_is_const or y_is_const + const = x[0] if x_is_const else y[0] + inp_lb = x[0] if y_is_const else y[0] + inp_ub = x[1] if y_is_const else y[1] + pos_mask = (const > 0).to(dtype=inp_lb.dtype) + neg_mask = 1. - pos_mask + lb = op(inp_lb, const * pos_mask) + op(inp_ub, const * neg_mask) + ub = op(inp_ub, const * pos_mask) + op(inp_lb, const * neg_mask) + return lb, ub + + def interval_propagate(self, *v): + if self.is_constant_op: + return self.interval_propagate_constant(*v) + else: + return self.interval_propagate_both_perturbed(*v) + + @staticmethod + def interval_propagate_both_perturbed(*v): x, y = v[0], v[1] if x is y: # A shortcut for x * x. 
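`interval_propagate_constant`, dispatched above, only needs sign masks: positive entries of the constant preserve the endpoint order and negative entries swap it. A toy standalone version of the same computation:

```python
import torch

const = torch.tensor([2.0, -3.0])
inp_lb = torch.tensor([-1.0, 0.5])
inp_ub = torch.tensor([4.0, 2.0])

pos_mask = (const > 0).to(inp_lb.dtype)
neg_mask = 1. - pos_mask
# Positive entries of const keep [lb, ub]; negative entries use [ub, lb].
lb = inp_lb * (const * pos_mask) + inp_ub * (const * neg_mask)
ub = inp_ub * (const * pos_mask) + inp_lb * (const * neg_mask)
assert (lb <= ub).all()
print(lb, ub)  # lb = [-2, -6], ub = [8, -1.5]
```

This is cheaper than the general four-product interval multiplication used when both inputs are perturbed, since the constant's sign is known ahead of time.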
@@ -133,27 +286,16 @@ def interval_propagate(*v): l = F.relu(h_L) - F.relu(-h_U) return l * l, torch.max(r0, r1) - if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y): - nominal = x.nominal * y.nominal - lower_offset = ( - x.nominal.clamp(min=0) * (y.lower_offset) + - x.nominal.clamp(max=0) * (y.upper_offset) + - y.nominal.clamp(min=0) * (x.lower_offset) + - y.nominal.clamp(max=0) * (x.upper_offset) + - torch.min(x.lower_offset * y.upper_offset, x.upper_offset * y.lower_offset)) - upper_offset = ( - x.nominal.clamp(min=0) * (y.upper_offset) + - x.nominal.clamp(max=0) * (y.lower_offset) + - y.nominal.clamp(min=0) * (x.upper_offset) + - y.nominal.clamp(max=0) * (x.lower_offset) + - torch.max(x.lower_offset * y.lower_offset, x.upper_offset * y.upper_offset)) - return Interval(None, None, nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset) - r0, r1, r2, r3 = x[0] * y[0], x[0] * y[1], x[1] * y[0], x[1] * y[1] lower = torch.min(torch.min(r0, r1), torch.min(r2, r3)) upper = torch.max(torch.max(r0, r1), torch.max(r2, r3)) return lower, upper + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + for vi in v: + assert isinstance(vi, Tensor), "build solver for BoundMul only with tensors for now" + self.solver_vars = v[0] * v[1] + @staticmethod def infer_batch_dim(batch_size, *x): if x[0] == -1: @@ -166,13 +308,24 @@ def infer_batch_dim(batch_size, *x): class BoundDiv(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.nonlinear = True + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.is_constant_op = False + for inp in inputs: + if isinstance(inp, (BoundConstant, BoundBuffers)): + # If any of the two inputs are constant, we do not need input bounds. + # FIXME (05/11/2022): this is just a temporary workaround. We need better way to determine whether we need input bounds, not just for BoundConstant. + # FIXME: unify this handling with BoundMul. + self.is_constant_op = True + if self.is_constant_op: + # One input is constant; no bounds required. + self.requires_input_bounds = [] + else: + # Both inputs are perturbed. Need relaxation. 
+            self.requires_input_bounds = [0, 1]

-    @Bound.save_io_shape
    def forward(self, x, y):
-        # ad-hoc implementation for layer normalization
+        # FIXME (05/11/2022): ad-hoc implementation for layer normalization
        if isinstance(self.inputs[1], BoundSqrt):
            input = self.inputs[0].inputs[0]
            x = input.forward_value
@@ -188,21 +341,37 @@
        return x / y

    def bound_backward(self, last_lA, last_uA, x, y):
+        if self.is_constant_op:
+            return BoundMul.bound_backward_constant(last_lA, last_uA, x, y, op=lambda x, const: BoundMul._multiply_by_const(x, 1/const))
+        else:
+            return self.bound_backward_both_perturbed(last_lA, last_uA, x, y)
+
+    def bound_backward_both_perturbed(self, last_lA, last_uA, x, y):
        reciprocal, mul, y_r = self._convert_to_mul(x, y)
        A, lower_b, upper_b = mul.bound_backward(last_lA, last_uA, x, y_r)
-
        A_y, lower_b_y, upper_b_y = reciprocal.bound_backward(A[1][0], A[1][1], y)
+        if isinstance(upper_b_y, Tensor) and upper_b_y.ndim == 1:
+            upper_b_y = upper_b_y.unsqueeze(-1)
+        if isinstance(lower_b_y, Tensor) and lower_b_y.ndim == 1:
+            lower_b_y = lower_b_y.unsqueeze(-1)
        upper_b = upper_b + upper_b_y
        lower_b = lower_b + lower_b_y
-
        return [A[0], A_y[0]], lower_b, upper_b

    def bound_forward(self, dim_in, x, y):
+        assert not self.is_constant_op
        reciprocal, mul, y_r = self._convert_to_mul(x, y)
        y_r_linear = reciprocal.bound_forward(dim_in, y)
-        y_r_linear = y_r_linear._replace(lower=y_r.lower, upper=y_r.upper)
+        y_r_linear.lower = y_r.lower
+        y_r_linear.upper = y_r.upper
        return mul.bound_forward(dim_in, x, y_r_linear)

-    def interval_propagate(self, *v):
+    def interval_propagate(self, *v):
+        if self.is_constant_op:
+            return BoundMul.interval_propagate_constant(*v, op=lambda x, const: x / const)
+        else:
+            return self.interval_propagate_both_perturbed(*v)
+
+    def interval_propagate_both_perturbed(self, *v):
        # ad-hoc implementation for layer normalization
        """
        1 / ( sqrt (1/n * sum_j Upper{(x_j-mu)^2/(x_i-mu)^2} ))

        Lower{(x_j-mu)^2/(x_i-mu)^2}
-            Lower{sum_j (x_j-mu)^2} / Upper{(x_i-mu)^2}
+            Lower{sum_j (x_j-mu)^2} / Upper{(x_i-mu)^2}
        Upper{(x_j-mu)^2/(x_i-mu)^2}
-            Upper{sum_j (x_j-mu)^2} / Lower{(x_i-mu)^2}
+            Upper{sum_j (x_j-mu)^2} / Lower{(x_i-mu)^2}
        """
        if isinstance(self.inputs[1], BoundSqrt):
            input = self.inputs[0].inputs[0]
            n = input.forward_value.shape[-1]
-
+
            h_L, h_U = input.lower, input.upper

            dev_lower = (
-                h_L * (1 - 1. / n) -
+                h_L * (1 - 1. / n) -
                (h_U.sum(dim=-1, keepdim=True) - h_U) / n
            )
            dev_upper = (
-                h_U * (1 - 1. / n) -
+                h_U * (1 - 1. / n) -
                (h_L.sum(dim=-1, keepdim=True) - h_L) / n
            )

-            dev_sqr_lower = (1 - (dev_lower < 0).float() * (dev_upper > 0).float()) * \
-                torch.min(dev_lower.abs(), dev_upper.abs())**2
+            dev_sqr_lower = (1 - (dev_lower < 0).to(dev_lower.dtype) * (dev_upper > 0).to(dev_lower.dtype)) * \
+                torch.min(dev_lower.abs(), dev_upper.abs())**2
            dev_sqr_upper = torch.max(dev_lower.abs(), dev_upper.abs())**2

            sum_lower = (dev_sqr_lower.sum(dim=-1, keepdim=True) - dev_sqr_lower) / dev_sqr_upper.clamp(min=epsilon)
@@ -245,8 +414,8 @@
                dev_sqr_lower.clamp(min=epsilon)
            sqrt_upper = torch.sqrt(1. / n * (sum_upper + 1))

-            lower = (dev_lower < 0).float() * (-1. / sqrt_lower) + (dev_lower > 0).float() * (1. / sqrt_upper)
-            upper = (dev_upper > 0).float() * (1. / sqrt_lower) + (dev_upper < 0).float() * (-1. / sqrt_upper)
+            lower = (dev_lower < 0).to(dev_lower.dtype) * (-1. / sqrt_lower) + (dev_lower > 0).to(dev_lower.dtype) * (1. / sqrt_upper)
+            upper = (dev_upper > 0).to(dev_upper.dtype) * (1.
/ sqrt_lower) + (dev_upper < 0).to(dev_upper.dtype) * (-1. / sqrt_upper) return lower, upper @@ -256,33 +425,35 @@ def interval_propagate(self, *v): def _convert_to_mul(self, x, y): try: - reciprocal = BoundReciprocal(self.input_name, self.name + '/reciprocal', self.ori_name, {}, [], 0, None, - self.device) - mul = BoundMul(self.input_name, self.name + '/mul', self.ori_name, {}, [], 0, None, self.device) + reciprocal = BoundReciprocal({}, [], 0, None) + mul = BoundMul({}, [], 0, None) except: # to make it compatible with previous code - reciprocal = BoundReciprocal(self.input_name, self.name + '/reciprocal', None, {}, [], 0, None, self.device) - mul = BoundMul(self.input_name, self.name + '/mul', None, {}, [], 0, None, self.device) + reciprocal = BoundReciprocal(None, {}, [], 0, None) + mul = BoundMul(None, {}, [], 0, None) reciprocal.output_shape = mul.output_shape = self.output_shape reciprocal.batch_dim = mul.batch_dim = self.batch_dim y_r = copy.copy(y) if isinstance(y_r, LinearBound): - y_r = y_r._replace(lower=1. / y.upper, upper=1. / y.lower) + y_r.lower = 1. / y.upper + y_r.upper = 1. / y.lower else: y_r.lower = 1. / y.upper y_r.upper = 1. / y.lower return reciprocal, mul, y_r - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + for vi in v: + assert isinstance(vi, Tensor), "build solver for BoundDiv only with tensors for now" + self.solver_vars = v[0] / v[1] class BoundAdd(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used. 
self.mode = options.get("conv_mode", "matrix") - @Bound.save_io_shape def forward(self, x, y): self.x_shape = x.shape self.y_shape = y.shape @@ -301,33 +472,59 @@ def _bound_oneside(last_A, w): return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 def bound_forward(self, dim_in, x, y): - x_lw, x_lb, x_uw, x_ub = Bound.broadcast_forward(dim_in, x, self.output_shape) - y_lw, y_lb, y_uw, y_ub = Bound.broadcast_forward(dim_in, y, self.output_shape) - lw, lb = x_lw + y_lw, x_lb + y_lb - uw, ub = x_uw + y_uw, x_ub + y_ub - return LinearBound(lw, lb, uw, ub) + lb, ub = x.lb + y.lb, x.ub + y.ub - def interval_propagate(self, x, y): - assert (not isinstance(y, torch.Tensor)) + def add_w(x_w, y_w, x_b, y_b): + if x_w is None and y_w is None: + return None + elif x_w is not None and y_w is not None: + return x_w + y_w + elif y_w is None: + return x_w + torch.zeros_like(y_b) + else: + return y_w + torch.zeros_like(x_b) - if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y): - return Interval( - None, None, - x.nominal + y.nominal, - x.lower_offset + y.lower_offset, - x.upper_offset + y.upper_offset) + lw = add_w(x.lw, y.lw, x.lb, y.lb) + uw = add_w(x.uw, y.uw, x.ub, y.ub) - return x[0] + y[0], x[1] + y[1] + return LinearBound(lw, lb, uw, ub) - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x) + def interval_propagate(self, x, y): + assert (not isinstance(y, Tensor)) + return x[0] + y[0], x[1] + y[1] + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): + # constants if both inputs are tensors + self.solver_vars = self.forward(v[0], v[1]) + return + # we have both gurobi vars as inputs + this_layer_shape = self.output_shape + gvar_array1 = np.array(v[0]) + gvar_array2 = np.array(v[1]) + assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] + + # flatten to create vars and constrs first + gvar_array1 = gvar_array1.reshape(-1) + gvar_array2 = gvar_array2.reshape(-1) + new_layer_gurobi_vars = [] + for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): + var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq') + new_layer_gurobi_vars.append(var) + + # reshape to the correct list shape of solver vars + self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() + model.update() class BoundSub(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used. 
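The `build_solver` implementations above follow one pattern: create one continuous Gurobi variable per output neuron and tie it to the input variables with an equality constraint. A minimal standalone sketch of that pattern, assuming `gurobipy` is installed and licensed (the library imports it through `solver_utils`; names and sizes here are toy values):

```python
import gurobipy as grb

model = grb.Model()
x = [model.addVar(lb=-1, ub=1, name=f'x_{i}') for i in range(3)]
y = [model.addVar(lb=-2, ub=2, name=f'y_{i}') for i in range(3)]
out = []
for i, (vx, vy) in enumerate(zip(x, y)):
    # One unbounded continuous variable per output neuron.
    v = model.addVar(lb=-float('inf'), ub=float('inf'),
                     vtype=grb.GRB.CONTINUOUS, name=f'add_{i}')
    # Equality constraint encodes the (linear) elementwise addition exactly.
    model.addConstr(v == vx + vy, name=f'add_{i}_eq')
    out.append(v)
model.update()
print(len(out), 'solver variables created')
```

Because addition and subtraction are linear, these layers add no relaxation error to the MIP/LP encoding; only the nonlinear layers do.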
+ self.mode = options.get("conv_mode", "matrix") - @Bound.save_io_shape def forward(self, x, y): self.x_shape = x.shape self.y_shape = y.shape @@ -337,7 +534,17 @@ def bound_backward(self, last_lA, last_uA, x, y): def _bound_oneside(last_A, w, sign=-1): if last_A is None: return None - return self.broadcast_backward(sign * last_A, w) + if isinstance(last_A, torch.Tensor): + return self.broadcast_backward(sign * last_A, w) + elif isinstance(last_A, Patches): + if sign == 1: + # Patches shape requires no broadcast. + return last_A + else: + # Multiply by the sign. + return last_A.create_similar(sign * last_A.patches) + else: + raise ValueError(f'Unknown last_A type {type(last_A)}') uA_x = _bound_oneside(last_uA, x, sign=1) uA_y = _bound_oneside(last_uA, y, sign=-1) @@ -346,32 +553,55 @@ def _bound_oneside(last_A, w, sign=-1): return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 def bound_forward(self, dim_in, x, y): - x_lw, x_lb, x_uw, x_ub = Bound.broadcast_forward(dim_in, x, self.output_shape) - y_lw, y_lb, y_uw, y_ub = Bound.broadcast_forward(dim_in, y, self.output_shape) - lw, lb = x_lw - y_uw, x_lb - y_ub - uw, ub = x_uw - y_lw, x_ub - y_lb + lb, ub = x.lb - y.ub, x.ub - y.lb + + def add_w(x_w, y_w, x_b, y_b): + if x_w is None and y_w is None: + return None + elif x_w is not None and y_w is not None: + return x_w + y_w + elif y_w is None: + return x_w + torch.zeros_like(y_b) + else: + return y_w + torch.zeros_like(x_b) + + lw = add_w(x.lw, -y.uw, x.lb, y.lb) + uw = add_w(x.uw, -y.lw, x.ub, y.ub) + return LinearBound(lw, lb, uw, ub) def interval_propagate(self, x, y): - if Interval.use_relative_bounds(x) and Interval.use_relative_bounds(y): - return Interval( - None, None, - x.nominal - y.nominal, - x.lower_offset - y.upper_offset, - x.upper_offset - y.lower_offset) - return x[0] - y[1], x[1] - y[0] - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): + # constants if both inputs are tensors + self.solver_vars = self.forward(v[0], v[1]) + return + # we have both gurobi vars as inputs + this_layer_shape = self.output_shape + gvar_array1 = np.array(v[0]) + gvar_array2 = np.array(v[1]) + assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] + + # flatten to create vars and constrs first + gvar_array1 = gvar_array1.reshape(-1) + gvar_array2 = gvar_array2.reshape(-1) + new_layer_gurobi_vars = [] + for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): + var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(var == (var1 - var2), name=f'lay{self.name}_{neuron_idx}_eq') + new_layer_gurobi_vars.append(var) + + # reshape to the correct list shape of solver vars + self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() + model.update() class BoundEqual(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) - @Bound.save_io_shape def forward(self, x, y): return x == y - - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x) \ No newline at end of file diff 
--git a/auto_LiRPA/operators/clampmult.py b/auto_LiRPA/operators/clampmult.py new file mode 100644 index 0000000..7241fb6 --- /dev/null +++ b/auto_LiRPA/operators/clampmult.py @@ -0,0 +1,238 @@ +"""Element multiplication with the A matrix based on its sign.""" +import torch +import time +from typing import Optional, Tuple +from torch import Tensor +from ..patches import Patches + + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + + +# @torch.jit.script +def _reference_multiply_by_A_signs(A: Tensor, d_pos: Tensor, d_neg: Tensor, + b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]: + """Reference implementation.""" + A_pos = A.clamp(min=0) + A_neg = A.clamp(max=0) + A_new = d_pos * A_pos + d_neg * A_neg + bias_pos = bias_neg = torch.tensor(0.) + if b_pos is not None: + if patches_mode: + bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos) + else: + bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos) + if b_neg is not None: + if patches_mode: + bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg) + else: + bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg) + return A_new, bias_pos + bias_neg + + +class ClampedMultiplication(torch.autograd.Function): + @staticmethod + @torch.jit.script + def clamp_mutiply_forward(A: Tensor, d_pos: Tensor, d_neg: Tensor, + b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]: + """Forward operations; actually the same as the reference implementation.""" + A_pos = A.clamp(min=0) + A_neg = A.clamp(max=0) + A_new = d_pos * A_pos + d_neg * A_neg + bias_pos = bias_neg = torch.tensor(0.) + if b_pos is not None: + if patches_mode: + bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos) + else: + bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos) + if b_neg is not None: + if patches_mode: + bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg) + else: + bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg) + return A_new, bias_pos + bias_neg + + @staticmethod + @torch.jit.script + def clamp_mutiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor, + b_pos: Optional[Tensor], b_neg: Optional[Tensor], grad_output_A: Tensor, grad_output_bias: Optional[Tensor], + patches_mode: bool) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], None]: + """Improved backward operation. This should be better than the backward function generated by Pytorch.""" + if grad_output_bias is not None: + extension_dim = len(A.shape) - len(grad_output_bias.shape) + grad_output_bias = grad_output_bias.view(grad_output_bias.shape + (1, ) * extension_dim) + A_pos_mask = (A >= 0).to(dtype=grad_output_A.dtype) + A_neg_mask = 1. - A_pos_mask + A_pos_grad_output_A = A_pos_mask * grad_output_A + A_neg_grad_output_A = A_neg_mask * grad_output_A + gd_pos = A * A_pos_grad_output_A + gd_neg = A * A_neg_grad_output_A + if b_pos is not None and b_neg is not None and grad_output_bias is not None: + A_pos_grad_output_bias = A_pos_mask * grad_output_bias + A_neg_grad_output_bias = A_neg_mask * grad_output_bias + gb_neg = A * A_neg_grad_output_bias + gb_pos = A * A_pos_grad_output_bias + # gA has 4 terms. 
+ gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias + b_neg * A_neg_grad_output_bias + elif b_neg is not None and grad_output_bias is not None: + A_neg_grad_output_bias = A_neg_mask * grad_output_bias + gb_neg = A * A_neg_grad_output_bias + gb_pos = None + # gA has 3 terms. + gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_neg * A_neg_grad_output_bias + elif b_pos is not None and grad_output_bias is not None: + A_pos_grad_output_bias = A_pos_mask * grad_output_bias + gb_pos = A * A_pos_grad_output_bias + gb_neg = None + # gA has 3 terms. + gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias + else: + # gA has 2 terms. + gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + gb_pos = gb_neg = None + return gA, gd_pos, gd_neg, gb_pos, gb_neg, None + + @staticmethod + def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode): + # No need to save the intermediate A_pos, A_neg as they have been fused into the computation. + ctx.save_for_backward(A, d_pos, d_neg, b_pos, b_neg) + ctx.patches_mode = patches_mode + return ClampedMultiplication.clamp_mutiply_forward(A, d_pos, d_neg, b_pos, b_neg, patches_mode) + + @staticmethod + def backward(ctx, grad_output_A, grad_output_bias): + A, d_pos, d_neg, b_pos, b_neg = ctx.saved_tensors + patches_mode = ctx.patches_mode + return ClampedMultiplication.clamp_mutiply_backward(A, d_pos, d_neg, b_pos, b_neg, grad_output_A, grad_output_bias, patches_mode) + + +def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'): + if isinstance(A, Tensor): + if contiguous is True or contiguous == 'auto': + # For dense mode, convert d_pos and d_neg to contiguous tensor by default. + d_pos = d_pos.contiguous() + d_neg = d_neg.contiguous() + if d_pos.ndim == 1: + # Special case for LSTM, the bias term is 1-dimension. (FIXME) + assert d_neg.ndim == 1 and b_pos.ndim == 1 and b_neg.ndim == 1 + new_A = A.clamp(min=0) * d_pos + A.clamp(max=0) * d_neg + new_bias = A.clamp(min=0) * b_pos + A.clamp(max=0) * b_neg + return new_A, new_bias + return ClampedMultiplication.apply(A, d_pos, d_neg, b_pos, b_neg, False) + elif isinstance(A, Patches): + if contiguous: + # For patches mode, do not convert d_pos and d_neg to contiguous tensor by default. + d_pos = d_pos.contiguous() + d_neg = d_neg.contiguous() + assert A.identity == 0 # TODO: handle the A.identity = 1 case. Currently not used. + patches = A.patches + patches_shape = patches.shape + # patches shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension. + # or (unstable_size, batch_size, in_c, H, W) when it is sparse. + if len(patches_shape) == 6: + patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:]) + d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_pos is not None else None + d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_neg is not None else None + b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_pos is not None else None + b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_neg is not None else None + # Apply the multiplication based on signs. + A_prod, bias = ClampedMultiplication.apply(patches, d_pos, d_neg, b_pos, b_neg, True) + # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse. + # For sparse patches the return bias size is (unstable_size, batch). 
+ # For regular patches the return bias size is (spec, batch, out_h, out_w). + if len(patches_shape) == 6: + A_prod = A_prod.view(*patches_shape) + return A.create_similar(A_prod), bias + + +def _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=False, n_test=20, warmup=3): + """Benchmarking function.""" + print(f'patches_mode = {patches_mode}, b_pos is {type(b_pos)}, b_neg is {type(b_neg)}') + total_ref = 0. + total_new = 0. + run = ['ref', 'new'] + for i in range(n_test): + ref_time = new_time = 0. + + if 'ref' in run: + torch.cuda.synchronize() + start = time.time() + ref_A, ref_bias = _reference_multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode) + ref_loss = ref_A.sum() + ref_bias.sum() + ref_loss.backward() + torch.cuda.synchronize() + ref_time = time.time() - start + ref_gA = A.grad.detach().clone() + ref_gd_pos = d_pos.grad.detach().clone() + ref_gd_neg = d_neg.grad.detach().clone() + ref_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.) + ref_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.) + A.grad = d_pos.grad = d_neg.grad = None + if b_pos is not None: + b_pos.grad = None + if b_neg is not None: + b_neg.grad = None + del ref_loss + + if 'new' in run: + torch.cuda.synchronize() + start = time.time() + new_A, new_bias = multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode) + new_loss = new_A.sum() + new_bias.sum() + new_loss.backward() + torch.cuda.synchronize() + new_time = time.time() - start + new_gA = A.grad.detach().clone() + new_gd_pos = d_pos.grad.detach().clone() + new_gd_neg = d_neg.grad.detach().clone() + new_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.) + new_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.) + A.grad = d_pos.grad = d_neg.grad = None + if b_pos is not None: + b_pos.grad = None + if b_neg is not None: + b_neg.grad = None + del new_loss + + print(f'Loop {i:3d} {"(warmup)" if i < warmup else " "} time ref {ref_time:.5f} new {new_time:.6f} speedup {ref_time / new_time if i >= warmup else float("nan"):.3f}') + if i >= warmup: + total_ref += ref_time + total_new += new_time + + if 'ref' in run and 'new' in run: + A_diff = (ref_A - new_A).abs().sum().item() / ref_A.abs().sum().item() + gA_diff = (ref_gA - new_gA).abs().sum().item() / ref_gA.abs().sum().item() + bias_diff = (ref_bias - new_bias).abs().sum().item() / (ref_bias.abs().sum().item() + 1e-10) + gd_pos_diff = (ref_gd_pos - new_gd_pos).abs().sum().item() / ref_gd_pos.abs().sum().item() + gd_neg_diff = (ref_gd_neg - new_gd_neg).abs().sum().item() / ref_gd_neg.abs().sum().item() + gb_pos_diff = (ref_gb_pos - new_gb_pos).abs().sum().item() / (ref_gb_pos.abs().sum().item() + 1e-10) + gb_neg_diff = (ref_gb_neg - new_gb_neg).abs().sum().item() / (ref_gb_neg.abs().sum().item() + 1e-10) + print(f' diff {A_diff} {gA_diff} {bias_diff} {gd_pos_diff} {gd_neg_diff} {gb_pos_diff} {gb_neg_diff}') + assert A_diff < 1e-6 and bias_diff < 1e-6 and gA_diff < 1e-6 and gd_pos_diff < 1e-6 and gd_neg_diff < 1e-6 + assert gb_pos_diff < 1e-6 and gb_neg_diff < 1e-6 + + + avg_ref_time = total_ref / (n_test - warmup) + avg_new_time = total_new / (n_test - warmup) + print(f'Avg. 
time: reference {avg_ref_time:.5f} new {avg_new_time:.6f} speedup {avg_ref_time / avg_new_time:.3f}') + + +if __name__ == '__main__': + for patches_mode in [True, False]: + if patches_mode: + shape = (256, 8, 8, 8, 16, 32) + else: + shape = (256, 8, 128, 256) + A = torch.randn(shape, device='cuda', requires_grad=True) + d_pos = torch.randn(shape, device='cuda', requires_grad=True) + d_neg = torch.randn(shape, device='cuda', requires_grad=True) + b_pos = torch.randn(shape, device='cuda', requires_grad=True) + b_neg = torch.randn(shape, device='cuda', requires_grad=True) + _speed_test(A, d_pos, d_neg, None, None, patches_mode=patches_mode) + _speed_test(A, d_pos, d_neg, None, b_neg, patches_mode=patches_mode) + _speed_test(A, d_pos, d_neg, b_pos, None, patches_mode=patches_mode) + _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=patches_mode) + print('Press Enter key to continue.') + input() + del A, d_pos, d_neg, b_pos, b_neg diff --git a/auto_LiRPA/operators/constant.py b/auto_LiRPA/operators/constant.py index 6d01883..9b6b52f 100644 --- a/auto_LiRPA/operators/constant.py +++ b/auto_LiRPA/operators/constant.py @@ -2,12 +2,11 @@ from .base import * class BoundConstant(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.value = attr['value'].to(self.device) self.use_default_ibp = True - @Bound.save_io_shape def forward(self): return self.value.to(self.device) @@ -19,7 +18,7 @@ def _bound_oneside(A): if A is None: return 0.0 - if type(A) == torch.Tensor: + if type(A) == Tensor: if A.ndim > 2: A = torch.sum(A, dim=list(range(2, A.ndim))) elif type(A) == Patches: @@ -39,22 +38,22 @@ def bound_forward(self, dim_in): lb = ub = self.value return LinearBound(lw, lb, uw, ub) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.value + + class BoundPrimConstant(Bound): - def __init__(self, input_name, name, ori_name, attr, input, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, input, output_index, options, device) - self.value = attr['value'] + def __init__(self, attr, input, output_index, options): + super().__init__(attr, input, output_index, options) - @Bound.save_io_shape def forward(self): return torch.tensor([], device=self.device) class BoundConstantOfShape(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.device = device + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.value = attr['value'].to(self.device) - @Bound.save_io_shape def forward(self, x): self.x = x self.from_input = True @@ -85,9 +84,13 @@ def bound_forward(self, dim_in, x): def interval_propagate(self, *v): self.x = v[0][0] - value = torch.ones(list(v[0][0]), device=self.device) * self.value + size = int(v[0][0].item()) if isinstance(v[0][0], Tensor) else v[0][0] + value = torch.ones(size, device=self.device) * self.value return value, value + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(v) + def infer_batch_dim(self, batch_size, *x): # FIXME Should avoid referring to batch_size; 
Treat `torch.Size` results differently if self.x[0] == batch_size: @@ -96,10 +99,10 @@ def infer_batch_dim(self, batch_size, *x): return -1 class BoundRange(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.device = attr['device'] - @Bound.save_io_shape def forward(self, start, end, step): if start.dtype == end.dtype == step.dtype == torch.int64: return torch.arange(start, end, step, dtype=torch.int64, device=self.device) @@ -111,10 +114,10 @@ def infer_batch_dim(self, batch_size, *x): return -1 class BoundATenDiag(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.device = attr['device'] - @Bound.save_io_shape def forward(self, x, diagonal=0): return torch.diag(x, diagonal=diagonal) @@ -125,10 +128,10 @@ def infer_batch_dim(self, batch_size, *x): return 1 # This is not a batch operation. class BoundATenDiagonal(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.device = attr['device'] - @Bound.save_io_shape def forward(self, x, offset=0, dim1=0, dim2=1): return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2) diff --git a/auto_LiRPA/operators/convolution.py b/auto_LiRPA/operators/convolution.py index 6bc5d3f..22343f1 100644 --- a/auto_LiRPA/operators/convolution.py +++ b/auto_LiRPA/operators/convolution.py @@ -1,14 +1,16 @@ -""" Convolution, pooling and padding operators""" +""" Convolution and padding operators""" from .base import * -from .activation import BoundOptimizableActivation +import numpy as np +from .solver_utils import grb +from ..patches import unify_shape, compute_patches_stride_padding, is_shape_used class BoundConv(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): + def __init__(self, attr, inputs, output_index, options): assert (attr['pads'][0] == attr['pads'][2]) assert (attr['pads'][1] == attr['pads'][3]) - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + super().__init__(attr, inputs, output_index, options) self.stride = attr['strides'] self.padding = [attr['pads'][0], attr['pads'][1]] @@ -18,14 +20,13 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio self.has_bias = True else: self.has_bias = False - self.to(device) + self.relu_followed = False + self.patches_start = True self.mode = options.get("conv_mode", "matrix") - self.relu_followed = False # denote whether this Conv is followed by a ReLU - # if self.relu_followed is False, we need to manually pad the conv patches. + # if self.relu_followed is False, we need to manually pad the conv patches. # If self.relu_followed is True, the patches are padded in the ReLU layer and the manual padding is not needed. 
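+        # Editorial note (not part of the original patch): `self.mode` above selects
+        # how backward (CROWN) A matrices are represented for this layer: "matrix"
+        # materializes a dense (spec, batch, in_c*in_h*in_w) tensor, while "patches"
+        # keeps the compact kernel view handled by the Patches class. A hypothetical
+        # usage sketch, assuming the library's usual `bound_opts` argument:
+        #   model = BoundedModule(net, dummy_input, bound_opts={"conv_mode": "patches"})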
- @Bound.save_io_shape def forward(self, *x): # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias bias = x[2] if self.has_bias else None @@ -46,14 +47,14 @@ def _bound_oneside(last_A): # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead. shape = last_A.shape # [spec, batch, C, H, W] dim = int(prod(shape[2:])) - dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device) + dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype) # last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A. # last_A.coeffs is the value of the non-zero element. dense_last_A = torch.scatter(dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), src=last_A.coeffs.unsqueeze(-1)) # We created a large A matrix and it will be handled below. last_A = dense_last_A.view(shape[0], shape[1], *shape[2:]) - if type(last_A) == torch.Tensor: + if type(last_A) == Tensor: shape = last_A.size() # when (W−F+2P)%S != 0, construct the output_padding output_padding0 = int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 * \ @@ -65,7 +66,8 @@ def _bound_oneside(last_A): groups=self.groups, output_padding=(output_padding0, output_padding1)) next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:]) if self.has_bias: - sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2) + # sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2) + sum_bias = torch.einsum('sbchw,c->sb', last_A, x[2].lower) else: sum_bias = 0 return next_A, sum_bias @@ -76,13 +78,15 @@ def _bound_oneside(last_A): if not self.relu_followed: # FIXME (09/20): Don't call it relu_followed. Instead, make this a property of A, called "padded" and propagate this property. # The last_A.patches was not padded, so we need to pad them here. # If this Conv layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again. - one_d = torch.ones(tuple(1 for i in self.output_shape), device=last_A.patches.device).expand(self.output_shape) - # After unfolding, the shape is (batch, out_h, out_w, in_c, h, w) - one_d_unfolded = inplace_unfold(one_d, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding) + one_d = torch.ones(tuple(1 for i in self.output_shape[1:]), device=last_A.patches.device, dtype=weight.dtype).expand(self.output_shape[1:]) + # Add batch dimension. + one_d = one_d.unsqueeze(0) + # After unfolding, the shape is (1, out_h, out_w, in_c, h, w) + one_d_unfolded = inplace_unfold(one_d, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) if last_A.unstable_idx is not None: # Move out_h, out_w dimension to the front for easier selection. one_d_unfolded_r = one_d_unfolded.permute(1, 2, 0, 3, 4, 5) - # for sparse patches the shape is (unstable_size, batch, in_c, h, w). + # for sparse patches the shape is (unstable_size, batch, in_c, h, w). Batch size is 1 so no need to select here. one_d_unfolded_r = one_d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] else: # Append the spec dimension. 
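+                    # Editorial note: unfolding an all-ones tensor of the output shape
+                    # yields a 0/1 mask marking which entries of each patch fall inside
+                    # the (unpadded) input image; multiplying the patches by this mask
+                    # is an inexpensive way to apply the zero padding that a following
+                    # ReLU layer would otherwise have handled.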
@@ -100,7 +104,7 @@ def _bound_oneside(last_A): sum_bias = 0 flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1)) - pieces = F.conv_transpose2d(flattened_patches, weight, stride=self.stride) + pieces = F.conv_transpose2d(flattened_patches, insert_zeros(weight, last_A.inserted_zeros), stride=self.stride) # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w). pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1)) @@ -114,37 +118,43 @@ def _bound_oneside(last_A): # Expand the batch dimension. pieces = pieces.expand(-1, last_A.shape[1], -1, -1, -1) # Do the same for the bias. - sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1) - # bias has shape (unstable_size, batch). - sum_bias = sum_bias.expand(-1, last_A.shape[1]) + if self.has_bias: + sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1) + # bias has shape (unstable_size, batch). + sum_bias = sum_bias.expand(-1, last_A.shape[1]) + else: + sum_bias = 0 else: assert weight.size(0) == last_A.shape[0] pieces = weight.view(weight.size(0), 1, 1, 1, weight.size(1), weight.size(2), weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1) # The bias (x[2].lower) has shape (out_c,); we need to make it (out_c, batch, out_h, out_w). # Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version - sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4]) + if self.has_bias: + sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4]) + else: + sum_bias = 0 else: raise NotImplementedError() padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) - stride = last_A.stride if last_A is not None else 1 + stride = last_A.stride if last_A is not None else (1, 1) + inserted_zeros = last_A.inserted_zeros if last_A is not None else 0 + output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0) - if type(padding) == int: - padding = padding * self.stride[0] + self.padding[0] - else: - padding = tuple(p * self.stride[0] + self.padding[0] for p in padding) - stride *= self.stride[0] + padding, stride, output_padding = compute_patches_stride_padding(self.input_shape, padding, stride, self.padding, self.stride, inserted_zeros, output_padding) - if pieces.shape[-1] > self.input_shape[-1]: # the patches is too large and from now on, we will use matrix mode instead of patches mode. + if inserted_zeros == 0 and not is_shape_used(output_padding) and pieces.shape[-1] > self.input_shape[-1]: # the patches are too large and from now on, we will use matrix mode instead of patches mode. # This is our desired matrix: the input will be flattened to (batch_size, input_channel*input_x*input_y) and multiplied by this matrix. # After multiplication, the desired output is (batch_size, out_channel*output_x*output_y). 
# A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w) A_matrix = patches_to_matrix(pieces, self.input_shape[1:], stride, padding, last_A.output_shape, last_A.unstable_idx) - if isinstance(sum_bias, torch.Tensor) and last_A.unstable_idx is None: + # print(f'Converting patches to matrix: old shape {pieces.shape}, size {pieces.numel()}; new shape {A_matrix.shape}, size {A_matrix.numel()}') + if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None: sum_bias = sum_bias.transpose(0, 1) sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1) A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front. return A_matrix, sum_bias - return Patches(pieces, stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), sum_bias + # print(f'Conv returns patches with size={pieces.size()}, stride={stride}, padding={padding}, inserted_zeros={inserted_zeros}, output_padding={output_padding}') + return Patches(pieces, stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape, inserted_zeros=last_A.inserted_zeros, output_padding=output_padding), sum_bias else: raise NotImplementedError() @@ -152,37 +162,102 @@ def _bound_oneside(last_A): uA_x, ubias = _bound_oneside(last_uA) return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias - def bound_forward(self, dim_in, *x): + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): if self.is_input_perturbed(1): raise NotImplementedError("Weight perturbation for convolution layers has not been implemented.") - weight = x[1].lb - bias = x[2].lb if self.has_bias else None - x = x[0] - input_dim = x.lb.shape[-2] * x.lb.shape[-1] - wshape = x.lw.shape - eye = torch.eye(input_dim).view(input_dim, 1, *x.lb.shape[-2:]) - weight = F.conv2d(eye, weight, None, self.stride, self.padding, self.dilation, self.groups) - weight = weight.view(input_dim, -1) - output_dim = weight.shape[-1] - bias = bias.view(1, -1, 1).repeat(1, 1, output_dim // bias.shape[0]).view(*self.output_shape[1:]) - batch_size = x.lb.shape[0] - - lw = (x.lw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(min=0)) + - x.uw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(max=0)))\ - .reshape(batch_size, dim_in, *self.output_shape[1:]) - uw = (x.uw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(min=0)) + - x.lw.reshape(batch_size, dim_in, -1).matmul(weight.clamp(max=0)))\ - .reshape(batch_size, dim_in, *self.output_shape[1:]) - - lb = (x.lb.reshape(batch_size, -1).matmul(weight.clamp(min=0)) + - x.ub.reshape(batch_size, -1).matmul(weight.clamp(max=0)))\ - .reshape(batch_size, *self.output_shape[1:]) + bias - ub = (x.ub.reshape(batch_size, -1).matmul(weight.clamp(min=0)) + - x.lb.reshape(batch_size, -1).matmul(weight.clamp(max=0)))\ - .reshape(batch_size, *self.output_shape[1:]) + bias - - return LinearBound(lw, lb, uw, ub) + assert self.dilation == (1, 1) or self.dilation == [1, 1] + # e.g., last layer input gurobi vars (3,32,32) + gvars_array = np.array(v[0]) + # pre_layer_shape (1,3,32,32) + pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape + # this layer shape (1,8,16,16) + this_layer_shape = self.output_shape + out_lbs, out_ubs = None, None + if hasattr(self, "lower"): + # self.lower shape (1,8,16,16) + out_lbs = self.lower.cpu().numpy() + out_ubs = self.upper.cpu().numpy() + + # current layer weight (8,3,4,4) + this_layer_weight = v[1].detach().cpu().numpy() + # current layer bias (8,) + this_layer_bias = None + if self.has_bias: + 
this_layer_bias = v[2].detach().cpu().numpy() + weight_shape2, weight_shape3 = this_layer_weight.shape[2], this_layer_weight.shape[3] + padding0, padding1 = self.padding[0], self.padding[1] + stride0, stride1 = self.stride[0], self.stride[1] + + new_layer_gurobi_vars = [] + new_layer_gurobi_constrs = [] + + neuron_idx = 0 + for out_chan_idx in range(this_layer_shape[1]): + out_chan_vars = [] + for out_row_idx in range(this_layer_shape[2]): + out_row_vars = [] + for out_col_idx in range(this_layer_shape[3]): + # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1)) + lin_expr = 0 + if self.has_bias: + lin_expr = this_layer_bias[out_chan_idx] + + for in_chan_idx in range(this_layer_weight.shape[1]): + + # new version of conv layer for building mip by skipping kernel loops + ker_row_min, ker_row_max = 0, weight_shape2 + in_row_idx_min = -padding0 + stride0 * out_row_idx + in_row_idx_max = in_row_idx_min + weight_shape2 - 1 + if in_row_idx_min < 0: + ker_row_min = -in_row_idx_min + if in_row_idx_max >= pre_layer_shape[2]: + ker_row_max = ker_row_max - in_row_idx_max + pre_layer_shape[2] -1 + in_row_idx_min, in_row_idx_max = max(in_row_idx_min, 0), min(in_row_idx_max, pre_layer_shape[2] - 1) + + ker_col_min, ker_col_max = 0, weight_shape3 + in_col_idx_min = -padding1 + stride1 * out_col_idx + in_col_idx_max = in_col_idx_min + weight_shape3 - 1 + if in_col_idx_min < 0: + ker_col_min = -in_col_idx_min + if in_col_idx_max >= pre_layer_shape[3]: + ker_col_max = ker_col_max - in_col_idx_max + pre_layer_shape[3] -1 + in_col_idx_min, in_col_idx_max = max(in_col_idx_min, 0), min(in_col_idx_max, pre_layer_shape[3] - 1) + + coeffs = this_layer_weight[out_chan_idx, in_chan_idx, ker_row_min:ker_row_max, ker_col_min:ker_col_max].reshape(-1) + + gvars = gvars_array[in_chan_idx, in_row_idx_min:in_row_idx_max+1, in_col_idx_min:in_col_idx_max+1].reshape(-1) + if solver_pkg == 'gurobi': + lin_expr += grb.LinExpr(coeffs, gvars) + else: + # lin_expr += coeffs@gvars + + for i in range(len(coeffs)): + try: + lin_expr += coeffs[i] * gvars[i] + except TypeError: + lin_expr += coeffs[i] * gvars[i].var + + + out_lb = out_lbs[0, out_chan_idx, out_row_idx, out_col_idx] if out_lbs is not None else -float('inf') + out_ub = out_ubs[0, out_chan_idx, out_row_idx, out_col_idx] if out_ubs is not None else float('inf') + var = model.addVar(lb=out_lb, ub=out_ub, + obj=0, vtype=grb.GRB.CONTINUOUS, + # name=f'lay{layer_idx}_[{out_chan_idx}, {out_row_idx}, {out_col_idx}]') + name=f'lay{self.name}_{neuron_idx}') + # model.addConstr(lin_expr == var, name=f'lay{layer_idx}_[{out_chan_idx}, {out_row_idx}, {out_col_idx}]_eq') + # new_layer_gurobi_constrs.append( + # model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')) + model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq') + neuron_idx += 1 + + out_row_vars.append(var) + out_chan_vars.append(out_row_vars) + new_layer_gurobi_vars.append(out_chan_vars) + + self.solver_vars = new_layer_gurobi_vars + # self.solver_constrs = new_layer_gurobi_constrs + model.update() def interval_propagate(self, *v, C=None): if self.is_input_perturbed(1): @@ -191,32 +266,6 @@ def interval_propagate(self, *v, C=None): norm = Interval.get_perturbation(v[0]) norm = norm[0] - if Interval.use_relative_bounds(*v): - bias = v[2].nominal if self.has_bias else None - if norm == np.inf: - weight = v[1].nominal - nominal = F.conv2d( - v[0].nominal, weight, bias, - self.stride, self.padding, self.dilation, self.groups) - lower_offset = (F.conv2d( - v[0].lower_offset, 
weight.clamp(min=0), None, - self.stride, self.padding, self.dilation, self.groups) + - F.conv2d( - v[0].upper_offset, weight.clamp(max=0), None, - self.stride, self.padding, self.dilation, self.groups)) - upper_offset = (F.conv2d( - v[0].upper_offset, weight.clamp(min=0), None, - self.stride, self.padding, self.dilation, self.groups) + - F.conv2d( - v[0].lower_offset, weight.clamp(max=0), None, - self.stride, self.padding, self.dilation, self.groups)) - return Interval( - None, None, nominal=nominal, - lower_offset=lower_offset, upper_offset=upper_offset - ) - else: - raise NotImplementedError - h_L, h_U = v[0] weight = v[1][0] bias = v[2][0] if self.has_bias else None @@ -247,15 +296,47 @@ def interval_propagate(self, *v, C=None): ss = center.shape deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3]) - + center = F.conv2d(mid, weight, bias, self.stride, self.padding, self.dilation, self.groups) upper = center + deviation lower = center - deviation return lower, upper + def bound_dynamic_forward(self, *x, max_dim=None, offset=0): + if self.is_input_perturbed(1) or self.is_input_perturbed(2): + raise NotImplementedError("Weight perturbation for convolution layers has not been implemented.") + weight = x[1].lb + bias = x[2].lb if self.has_bias else None + x = x[0] + w = x.lw + b = x.lb + shape = w.shape + shape_wconv = [shape[0] * shape[1]] + list(shape[2:]) + def conv2d(input, weight, bias, stride, padding, dilation, groups): + """ There may be some CUDA error (illegal memory access) when + the batch size is too large. Thus split the input into several + batches when needed. """ + max_batch_size = 50 + if input.device != torch.device('cpu') and input.shape[0] > max_batch_size: + ret = [] + for i in range((input.shape[0] + max_batch_size - 1) // max_batch_size): + ret.append(F.conv2d( + input[i*max_batch_size:(i+1)*max_batch_size], + weight, bias, stride, padding, dilation, groups)) + return torch.cat(ret, dim=0) + else: + return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + w_new = conv2d( + w.reshape(shape_wconv), weight, None, self.stride, self.padding, + self.dilation, self.groups) + w_new = w_new.reshape(shape[0], -1, *w_new.shape[1:]) + b_new = conv2d( + b, weight, bias, self.stride, self.padding, self.dilation, self.groups) + return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim) + def bound_forward(self, dim_in, *x): - if self.is_input_perturbed(1): + if self.is_input_perturbed(1) or self.is_input_perturbed(2): raise NotImplementedError("Weight perturbation for convolution layers has not been implemented.") weight = x[1].lb @@ -270,16 +351,16 @@ def bound_forward(self, dim_in, *x): shape = mid_w.shape shape_wconv = [shape[0] * shape[1]] + list(shape[2:]) deviation_w = F.conv2d( - diff_w.reshape(shape_wconv), weight_abs, None, + diff_w.reshape(shape_wconv), weight_abs, None, self.stride, self.padding, self.dilation, self.groups) deviation_b = F.conv2d( - diff_b, weight_abs, None, + diff_b, weight_abs, None, self.stride, self.padding, self.dilation, self.groups) center_w = F.conv2d( - mid_w.reshape(shape_wconv), weight, None, + mid_w.reshape(shape_wconv), weight, None, self.stride, self.padding, self.dilation, self.groups) - center_b = F.conv2d( - mid_b, weight, bias, + center_b = F.conv2d( + mid_b, weight, bias, self.stride, self.padding, self.dilation, self.groups) deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:]) center_w = center_w.reshape(shape[0], -1, 
*center_w.shape[1:]) @@ -290,164 +371,219 @@ def bound_forward(self, dim_in, *x): uw = center_w + deviation_w, ub = center_b + deviation_b) +class BoundConvTranspose(Bound): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + assert (attr['pads'][0] == attr['pads'][2]) + assert (attr['pads'][1] == attr['pads'][3]) -class BoundMaxPool(BoundOptimizableActivation): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2]) - assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3]) - - self.nonlinear = True - self.kernel_size = attr['kernel_shape'] self.stride = attr['strides'] self.padding = [attr['pads'][0], attr['pads'][1]] - self.ceil_mode = False - self.use_default_ibp = True - self.alpha = None - self.init = {} - - @Bound.save_io_shape - def forward(self, x): - output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + self.dilation = attr['dilations'] + self.groups = attr['group'] + self.output_padding = [attr.get('output_padding', [0, 0])[0], attr.get('output_padding', [0, 0])[1]] + if len(inputs) == 3: + self.has_bias = True + else: + self.has_bias = False + self.mode = options.get("conv_mode", "matrix") + assert self.output_padding == [0, 0] + assert self.padding == [0, 0] + assert self.dilation == [1, 1] + assert self.stride[0] == self.stride[1] + assert self.groups == 1 + + def forward(self, *x): + # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias + bias = x[2] if self.has_bias else None + output = F.conv_transpose2d(x[0], x[1], bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding) return output - def project_simplex(self, patches): - sorted = torch.flatten(patches, -2) - sorted, _ = torch.sort(sorted, -1, descending=True) - rho_sum = torch.cumsum(sorted, -1) - rho_value = 1 - rho_sum - rho_value = (sorted + rho_value/torch.tensor(range(1, sorted.size(-1)+1), dtype=torch.float, device=sorted.device)) > 0 - _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1) - rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1) - lbd = 1/(rho_index+1)* (1-rho_sum) - - return torch.clamp(patches + lbd.unsqueeze(-1).unsqueeze(-1), min=0) - - def init_opt_parameters(self, start_nodes): - batch_size, channel, h, w = self.input_shape - o_h, o_w = self.output_shape[-2:] - # batch_size, out_c, out_h, out_w, k, k - - self.alpha = OrderedDict() - ref = self.inputs[0].lower # a reference variable for getting the shape - for ns, size_s in start_nodes: - self.alpha[ns] = torch.empty([1, size_s, self.input_shape[0], self.input_shape[1], self.output_shape[-2], self.output_shape[-1], self.kernel_size[0], self.kernel_size[1]], - dtype=torch.float, device=ref.device, requires_grad=True) - self.init[ns] = False - - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, unstable_idx=None): - paddings = tuple(self.padding + self.padding) - - A_shape = last_lA.shape if last_lA is not None else last_uA.shape - # batch_size, input_c, x, y - upper_d = torch.zeros((list(self.input_shape)), device=x.device) - lower_d = torch.zeros((list(self.input_shape)), device=x.device) - - upper_d = F.pad(upper_d, paddings) - lower_d = F.pad(lower_d, paddings) - - # 
batch_size, output_c, x, y - upper_b = torch.zeros((list(self.output_shape)), device=x.device) - lower_b = torch.zeros((list(self.output_shape)), device=x.device) - - # 1. find the index i where li > uj for all j, then set upper_d = lower_d = 1 - max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) - delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape) - max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode) - - values = torch.zeros_like(max_lower) - values[max_lower >= max_upper] = 1.0 - upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape) - - if self.opt_stage == 'opt': - if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1: - if unstable_idx.ndim == 1: - # Only unstable neurons of the start_node neurons are used. - alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) - elif unstable_idx.ndim == 2: - # Each element in the batch selects different neurons. - alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) - else: - raise ValueError - else: - alpha = self.alpha[start_node.name] - - if self.init[start_node.name] == False: - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) - lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride) - - alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1]) - alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach()) - self.init[start_node.name] = True - if self.padding[0] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] - - alpha.data = self.project_simplex(alpha.data).clone().detach() - alpha = alpha.permute((0,1,2,3,6,7,4,5)) - alpha_shape = alpha.shape - alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1])) - lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride) - lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:]) - lower_d = lower_d.squeeze(0) - else: - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) - if self.padding[0] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] - values[:] = 0.0 - max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) - values[max_upper > max_lower] = max_upper_[max_upper > max_lower] - upper_b = values + def bound_backward(self, last_lA, last_uA, *x): + if self.is_input_perturbed(1): + raise NotImplementedError("Weight perturbation for convolution layers has not been implemented.") + + lA_y = uA_y = lA_bias = uA_bias = None + weight = x[1].lower + assert weight.size(-1) == weight.size(-2) - assert type(last_lA) == torch.Tensor or type(last_uA) == torch.Tensor - def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): + def _bound_oneside(last_A): if last_A is None: return None, 0 - pos_A = last_A.clamp(min=0) - neg_A = last_A.clamp(max=0) + if
type(last_A) is OneHotC: + # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead. + shape = last_A.shape # [spec, batch, C, H, W] + dim = int(prod(shape[2:])) + dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype) + # last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A. + # last_A.coeffs is the value of the non-zero element. + dense_last_A = torch.scatter(dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), src=last_A.coeffs.unsqueeze(-1)) + # We created a large A matrix and it will be handled below. + last_A = dense_last_A.view(shape[0], shape[1], *shape[2:]) - bias = 0 - if b_pos is not None: - bias = bias + self.get_bias(pos_A, b_pos) - if b_neg is not None: - bias = bias + self.get_bias(neg_A, b_neg) + if type(last_A) == Tensor: + shape = last_A.size() + next_A = F.conv2d(last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None, + stride=self.stride, padding=self.padding, dilation=self.dilation, + groups=self.groups) + next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:]) + if self.has_bias: + sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2) + else: + sum_bias = 0 + return next_A, sum_bias + elif type(last_A) == Patches: + # Here we build and propagate a Patch object with (patches, stride, padding) + assert type(last_A) == Patches + if last_A.identity == 0: + patches = last_A.patches - shape = last_A.size() - pos_A = F.interpolate(pos_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) - pos_A = F.pad(pos_A, (0, self.input_shape[-2] - pos_A.shape[-2], 0, self.input_shape[-1] - pos_A.shape[-1])) - pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:]) + # FIXME: so far, assume there will be a relu layer in its input. - neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) - neg_A = F.pad(neg_A, (0, self.input_shape[-2] - neg_A.shape[-2], 0, self.input_shape[-1] - neg_A.shape[-1])) - neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:]) + if self.has_bias: + # bias is x[2] (lower and upper are the same), and has shape (c,). + # Patches either has [out_c, batch, out_h, out_w, c, h, w] or [unstable_size, batch, c, h, w]. + sum_bias = torch.einsum('sb...chw,c->sb...', patches, x[2].lower) + # sum_bias has shape (out_c, batch, out_h, out_w) or (unstable_size, batch). + else: + sum_bias = 0 - next_A = pos_A * d_pos + neg_A * d_neg - return next_A, bias + flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1)) + # Merge patches with this layer's weights. Weight must be flipped here; and if stride != 1, we must insert zeros in the input image. + # For conv_transpose2d, the weight matrix is in the (in, out, k, k) shape. + pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=last_A.inserted_zeros + 1) + # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w). + pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1)) - if self.padding[0] > 0: - upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + elif last_A.identity == 1: + # New patches have size [out_c, batch, out_h, out_w, c, h, w] if it is not sparse. + # New patches have size [unstable_size, batch, c, h, w] if it is sparse. 
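+                # Editorial note: identity == 1 means last_A is an implicit identity
+                # matrix (each output spec selects a single neuron), so the propagated
+                # patches are simply this layer's kernel weights, expanded over the
+                # batch dimension below.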
+ if last_A.unstable_idx is not None: + raise NotImplementedError() + pieces = weight.view(weight.size(0), 1, weight.size(1), weight.size(2), weight.size(3)) + # Select based on the output channel (out_h and out_w are irrelevant here). + pieces = pieces[last_A.unstable_idx[0]] + # Expand the batch dimension. + pieces = pieces.expand(-1, last_A.shape[1], -1, -1, -1) + # Do the same for the bias. + sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1) + # bias has shape (unstable_size, batch). + sum_bias = sum_bias.expand(-1, last_A.shape[1]) + else: + assert weight.size(0) == last_A.shape[0] + pieces = weight.view(weight.size(0), 1, 1, 1, weight.size(1), weight.size(2), weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1) + # The bias (x[2].lower) has shape (out_c,); we need to make it (out_c, batch, out_h, out_w). + # Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version + sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4]) + else: + raise NotImplementedError() + padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) + output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) + inserted_zeros = last_A.inserted_zeros + assert self.padding == [0, 0] + assert self.stride[0] == self.stride[1] + + # Unify the shape to 4-tuple. + output_padding = unify_shape(output_padding) + padding = unify_shape(padding) + this_stride = unify_shape(self.stride) + this_padding = unify_shape(self.padding) + + # Compute new padding. + padding = tuple(p + (weight.size(3 - j//2) - 1) for j, p in enumerate(padding)) + + # Compute new output padding + output_padding = tuple(p * this_stride[j] + this_padding[j] for j, p in enumerate(output_padding)) + # When we run insert_zeros, it's missing the rightmost column and the bottom row. + # padding = (padding[0], padding[1] + inserted_zeros, padding[2], padding[3] + inserted_zeros) + + # If no transposed conv has been encountered so far, inserted_zeros is 0. + # When a transposed conv is encountered, the stride is multiplied into it. + inserted_zeros = (inserted_zeros + 1) * this_stride[0] - 1 + + # FIXME: disabled patches_to_matrix because not all parameters are supported. + if inserted_zeros == 0 and not is_shape_used(output_padding) and pieces.shape[-1] > self.input_shape[-1]: # the patches are too large and from now on, we will use matrix mode instead of patches mode. + # This is our desired matrix: the input will be flattened to (batch_size, input_channel*input_x*input_y) and multiplied by this matrix. + # After multiplication, the desired output is (batch_size, out_channel*output_x*output_y). + # A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w) + assert inserted_zeros == 0 + A_matrix = patches_to_matrix(pieces, self.input_shape[1:], last_A.stride, padding, last_A.output_shape, last_A.unstable_idx) + if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None: + sum_bias = sum_bias.transpose(0, 1) + sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1) + A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front. 
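+                    # Editorial note: once this dense fallback triggers, downstream
+                    # layers receive an ordinary A matrix of shape
+                    # (spec, batch, in_c*in_h*in_w), and patches mode is abandoned for
+                    # the remainder of this backward pass.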
+ return A_matrix, sum_bias + return Patches(pieces, last_A.stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, + output_shape=last_A.output_shape, inserted_zeros=inserted_zeros, output_padding=output_padding), sum_bias + else: + raise NotImplementedError() + + lA_x, lbias = _bound_oneside(last_lA) + uA_x, ubias = _bound_oneside(last_uA) + return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias + + def interval_propagate(self, *v, C=None): + if self.is_input_perturbed(1): + raise NotImplementedError("Weight perturbation for convolution layers has not been implemented.") + + norm = Interval.get_perturbation(v[0]) + norm = norm[0] + + h_L, h_U = v[0] + weight = v[1][0] + bias = v[2][0] if self.has_bias else None + + if norm == np.inf: + mid = (h_U + h_L) / 2.0 + diff = (h_U - h_L) / 2.0 + weight_abs = weight.abs() + deviation = F.conv_transpose2d(diff, weight_abs, None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding) + elif norm > 0: + raise NotImplementedError() + norm, eps = Interval.get_perturbation(v[0]) + # L2 norm, h_U and h_L are the same. + mid = h_U + # TODO: padding + deviation = torch.mul(weight, weight).sum((1, 2, 3)).sqrt() * eps + deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + else: # Here we calculate the L0 norm IBP bound using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020] + raise NotImplementedError() + norm, eps, ratio = Interval.get_perturbation(v[0]) + mid = h_U + k = int(eps) + weight_sum = torch.sum(weight.abs(), 1) + deviation = torch.sum(torch.topk(weight_sum.view(weight_sum.shape[0], -1), k)[0], dim=1) * ratio - uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, upper_b, lower_b) + if self.has_bias: + center = F.conv2d(mid, weight, v[2][0], self.stride, self.padding, self.dilation, self.groups) + else: + center = F.conv2d(mid, weight, None, self.stride, self.padding, self.dilation, self.groups) - lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b) + ss = center.shape + deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3]) - return [(lA, uA)], lbias, ubias + + center = F.conv_transpose2d(mid, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding) + + upper = center + deviation + lower = center - deviation + return lower, upper class BoundPad(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - if len(attr) == 1: - self.padding = [0, 0, 0, 0] - self.value = 0.0 - else: + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + if 'pads' in attr: self.padding = attr['pads'][2:4] + attr['pads'][6:8] - self.value = attr['value'] + else: + self.padding = [0, 0, 0, 0] + self.value = attr.get('value', 0.0) assert self.padding == [0, 0, 0, 0] - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - @Bound.save_io_shape def forward(self, x, pad, value=0.0): # TODO: padding for 3-D or more dimensional inputs. 
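+        # Editorial note: ONNX "Pad" encodes pads as
+        #   [x1_begin, x2_begin, ..., x1_end, x2_end, ...],
+        # so for an NCHW input pad = [0, 0, pad_top, pad_left, 0, 0, pad_bottom, pad_right],
+        # while F.pad expects (left, right, top, bottom); hence the reordering below.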
assert x.ndim == 4 # x[1] should be [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right] + assert pad[0] == pad[1] == pad[4] == pad[5] == 0 pad = [int(pad[3]), int(pad[7]), int(pad[2]), int(pad[6])] final = F.pad(x, pad, value=value) self.padding, self.value = pad, value @@ -459,7 +595,6 @@ def interval_propagate(self, *v): def bound_backward(self, last_lA, last_uA, *x): # TODO: padding for 3-D or more dimensional inputs. - pad = self.padding left, right, top, bottom = self.padding def _bound_oneside(last_A): if last_A is None: @@ -470,7 +605,7 @@ def _bound_oneside(last_A): new_padding = (last_A.padding[0] + left, last_A.padding[1] + right, last_A.padding[2] + top, last_A.padding[3] + bottom) else: new_padding = (last_A.padding + left, last_A.padding + right, last_A.padding + top, last_A.padding + bottom) - return Patches(last_A.patches, last_A.stride, new_padding, last_A.shape, last_A.identity, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) + return last_A.create_similar(padding=new_padding) else: shape = last_A.size() return last_A[:, :, :, top:(shape[3] - bottom), left:(shape[4] - right)] @@ -478,61 +613,42 @@ def _bound_oneside(last_A): last_uA = _bound_oneside(last_uA) return [(last_lA, last_uA), (None, None), (None, None)], 0, 0 -class BoundGlobalAveragePool(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - - @Bound.save_io_shape - def forward(self, x): - output = nn.AdaptiveAvgPool2d((1, 1)).forward(x) # adaptiveAveragePool with output size (1, 1) - return output - - def bound_backward(self, last_lA, last_uA, x): - H, W = self.input_shape[-2], self.input_shape[-1] - - lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None - uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None - - return [(lA, uA)], 0, 0 - - def interval_propagate(self, *v): - h_L, h_U = v[0] - h_L = F.adaptive_avg_pool2d(h_L, (1, 1)) - h_U = F.adaptive_avg_pool2d(h_U, (1, 1)) - return h_L, h_U - -class BoundAveragePool(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - # assumptions: ceil_mode=False, count_include_pad=True - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - - assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2]) - assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3]) - - self.kernel_size = attr['kernel_shape'] - self.stride = attr['strides'] - self.padding = [attr['pads'][0], attr['pads'][1]] - self.ceil_mode = False - self.count_include_pad = True - self.use_default_ibp = True + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., last layer input gurobi vars (3,32,32) + gvars_array = np.array(v[0]) + # pre_layer_shape (1,3,32,32) + pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape + # this layer shape (1,3,35,35) + this_layer_shape = self.output_shape + # v1 = tensor([0, 0, 1, 1, 0, 0, 2, 2]) + # [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right] + # => [left, right, top, bottom] + padding = [int(v[1][3]), int(v[1][7]), int(v[1][2]), int(v[1][6])] + left, right, top, bottom = padding + assert pre_layer_shape[2] + padding[0] + padding[1] == this_layer_shape[2] + assert pre_layer_shape[3] + padding[2] + padding[3] == this_layer_shape[3] + + new_layer_gurobi_vars = [] + neuron_idx = 
0 + for out_chan_idx in range(this_layer_shape[1]): + out_chan_vars = [] + for out_row_idx in range(this_layer_shape[2]): + out_row_vars = [] + row_pad = out_row_idx < left or out_row_idx >= this_layer_shape[2] - right + for out_col_idx in range(this_layer_shape[3]): + col_pad = out_col_idx < top or out_col_idx >= this_layer_shape[3] - bottom + if row_pad or col_pad: + v = model.addVar(lb=0, ub=0, + obj=0, vtype=grb.GRB.CONTINUOUS, + name=f'pad{self.name}_{neuron_idx}') + else: + v = gvars_array[out_chan_idx, out_row_idx - left, out_col_idx - top] + # print(out_chan_idx, out_row_idx, out_col_idx, row_pad, col_pad, v.LB, v.UB) + neuron_idx += 1 - @Bound.save_io_shape - def forward(self, x): - return F.avg_pool2d(x, self.kernel_size, self.stride, - self.padding, self.ceil_mode, self.count_include_pad) + out_row_vars.append(v) + out_chan_vars.append(out_row_vars) + new_layer_gurobi_vars.append(out_chan_vars) - def bound_backward(self, last_lA, last_uA, x): - def _bound_oneside(last_A): - if last_A is None: - return None, 0 - shape = last_A.size() - # propagate A to the next layer, with batch concatenated together - next_A = F.interpolate(last_A.view(shape[0] * shape[1], *shape[2:]), - scale_factor=self.kernel_size) / (prod(self.kernel_size)) - next_A = F.pad(next_A, (0, self.input_shape[-2] - next_A.shape[-2], 0, self.input_shape[-1] - next_A.shape[-1])) - next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:]) - return next_A, 0 - - lA, lbias = _bound_oneside(last_lA) - uA, ubias = _bound_oneside(last_uA) - return [(lA, uA)], lbias, ubias \ No newline at end of file + self.solver_vars = new_layer_gurobi_vars + model.update() diff --git a/auto_LiRPA/operators/cut_ops.py b/auto_LiRPA/operators/cut_ops.py new file mode 100644 index 0000000..e4337ae --- /dev/null +++ b/auto_LiRPA/operators/cut_ops.py @@ -0,0 +1,588 @@ +""" Cut operators""" +from .base import * +from .clampmult import multiply_by_A_signs + + +class CutModule(): + # store under BoundedModule + def __init__(self, relu_nodes=[], general_beta=None, x_coeffs=None, + active_cuts=None, cut_bias=None): + # all dict, storing cut parameters for each start node + # {start node name: (2 (lA, uA), spec (out_c, out_h, out_w), batch, num_cuts)} + self.general_beta = general_beta + # {start node name: (# active cut constraints,)} + self.active_cuts = active_cuts + + # all dict with tensor, storing coeffs for each relu layer, no grad + # coeffs: {relu layername: (num_cuts, flattened_nodes)} + self.relu_coeffs, self.arelu_coeffs, self.pre_coeffs = {}, {}, {} + for m in relu_nodes: + self.relu_coeffs[m.name] = self.arelu_coeffs[m.name] = self.pre_coeffs[m.name] = None + + # single tensor, always the same, no grad + # bias: (num_cuts,) + self.cut_bias = cut_bias + # x_coeffs: (num_cuts, flattened input dims) + self.x_coeffs = x_coeffs + + def use_patches(self, start_node): + # check if we are using patches mode for the start node + A = start_node.lA if start_node.lA is not None else start_node.uA + return type(A) is Patches + + def select_active_general_beta(self, start_node, unstable_idx=None): + # if one constraint have nodes deeper than start node, we do not count its effect for now + # self.general_beta[start_node.name]: (2(0 lower, 1 upper), spec (out_c, out_h, out_w/# fc nodes), batch, num_constrs) + # self.active_cuts[start_node.name]: a long() tensor with constraint index that + # should be index on current layer with current start node + if self.general_beta[start_node.name].ndim == 4: + general_beta = 
self.general_beta[start_node.name][:, :, :, self.active_cuts[start_node.name]] + elif self.general_beta[start_node.name].ndim == 6: + general_beta = self.general_beta[start_node.name][:, :, :, :, :, self.active_cuts[start_node.name]] + else: + print("general beta shape not supported!") + exit() + if unstable_idx is not None: + if self.use_patches(start_node): + general_beta = general_beta[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], :, :] + else: + # matrix mode + if general_beta.ndim == 6: + # conv layers general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs) + _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape + general_beta = general_beta.view(2, -1, batch, num_constrs) + else: + # dense layers general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs) + pass + general_beta = general_beta[:, unstable_idx] + else: + # unstable_idx is None + if general_beta.ndim == 6: + # flatten spec layer shape + _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape + general_beta = general_beta.view(2, -1, batch, num_constrs) + return general_beta + + def general_beta_coeffs_mm(self, unstable_spec_beta, coeffs, A, current_layer_shape): + if type(A) is Patches: + # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA + # coeffs: (num_constrs, current_c, current_h, current_w) + # coeffs_unfolded: (num_constrs, out_h, out_w, in_c, H, W) + # current_layer_shape = x.lower.size()[1:] + coeffs_unfolded = inplace_unfold(coeffs.view(-1, *current_layer_shape), \ + kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride) + # unstable_coeffs_unfolded: (num_constrs, unstable, in_c, H, W) + # A.unstable_idx is the unstable idx for spec layer + unstable_coeffs_unfolded = coeffs_unfolded[:, A.unstable_idx[1], A.unstable_idx[2], :, :, :] + # A.unstable_idx: unstable index on out_c, out_h and out_w + # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs) + # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs) + # unstable_spec_beta = general_beta[:, A.unstable_idx[0],\ + # A.unstable_idx[1], A.unstable_idx[2], :, :] + # beta_mm_coeffs_unfolded: (2(0 lower, 1 upper), unstable, batch, in_c, H, W) + beta_mm_coeffs = torch.einsum('sihj,jiabc->sihabc', unstable_spec_beta, unstable_coeffs_unfolded) + else: + # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs) + # coeffs: (num_constrs, current flattened layer nodes) + # beta_mm_coeffs: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes) + beta_mm_coeffs = torch.einsum('sihj,ja->siha', unstable_spec_beta, coeffs) + assert beta_mm_coeffs[0].numel() == A.numel(), f"the shape of beta is not initialized correctly! {beta_mm_coeffs[0].shape} v.s. 
{A.shape}
+        return beta_mm_coeffs.reshape(2, *A.shape)
+
+    def general_beta_coeffs_addmm_to_A(self, lA, uA, general_beta, coeffs, current_layer_shape):
+        A = lA if lA is not None else uA
+        # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)
+        # coeffs: (num_constrs, current_c, current_h, current_w)
+        # beta_mm_coeffs[0] shape is the same as A
+        # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)
+        # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)
+        beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, coeffs, A, current_layer_shape)
+        assert beta_mm_coeffs[0].shape == A.shape
+        if type(A) is Patches:
+            # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA
+            # lA_patches: (unstable, batch, in_c, H, W)
+            if lA is not None:
+                lA = Patches(lA.patches - beta_mm_coeffs[0], A.stride, A.padding, \
+                             A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)
+            if uA is not None:
+                uA = Patches(uA.patches + beta_mm_coeffs[1], A.stride, A.padding, \
+                             A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)
+        else:
+            # dense layers
+            if lA is not None:
+                lA = lA - beta_mm_coeffs[0]
+            if uA is not None:
+                uA = uA + beta_mm_coeffs[1]
+        return lA, uA
+
+    def patch_trick(self, start_node, layer_name, A, current_layer_shape):
+        ######## A problem with patches mode for cut constraints ##########
+        # A node may appear in a cut constraint without being covered by the patch
+        # selected for the output node. Trick: only keep the small patches whose
+        # unfolded coefficients account for the whole constraint, i.e.,
+        # coeffs_unfolded[ci][out_h, out_w, -1].sum() equals coeffs[ci].sum();
+        # for all other patches we force their beta to 0 to disable the effect of
+        # those constraints.
+        # This only applies if the current layer uses patches mode; if the target
+        # layer is patches but the current layer is not, we should not use it!
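A minimal shape sketch of the dense-mode contraction in `general_beta_coeffs_mm` above; all sizes below are illustrative, not from the library:

```python
import torch

# Illustrative sizes: 2 bound sides, 5 unstable spec neurons,
# 3 batch elements, 4 cut constraints, 7 flattened nodes in this layer.
s, i, h, j, a = 2, 5, 3, 4, 7
unstable_spec_beta = torch.rand(s, i, h, j)   # (2, unstable, batch, num_constrs)
coeffs = torch.randn(j, a)                    # (num_constrs, flattened nodes)

# Same contraction as the dense branch of general_beta_coeffs_mm.
beta_mm_coeffs = torch.einsum('sihj,ja->siha', unstable_spec_beta, coeffs)
assert beta_mm_coeffs.shape == (s, i, h, a)

# Equivalent matmul form: fold (s, i, h) into one dim and multiply by coeffs.
ref = (unstable_spec_beta.reshape(-1, j) @ coeffs).reshape(s, i, h, a)
assert torch.allclose(beta_mm_coeffs, ref)
```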
+ assert type(A) is Patches, "this trick fix only works for patches mode" + # unstable_spec_beta stores the current propagation, self.general_beta[start_node.name] selected with active_cuts, spec unstable + coeffs = 0 + if layer_name != "input": + if self.relu_coeffs[layer_name] is not None: + coeffs = coeffs + self.relu_coeffs[layer_name] + if self.arelu_coeffs[layer_name] is not None: + coeffs = coeffs + self.arelu_coeffs[layer_name] + if self.pre_coeffs[layer_name] is not None: + coeffs = coeffs + self.pre_coeffs[layer_name] + else: + if self.x_coeffs is not None: + coeffs = coeffs + self.x_coeffs + coeffs_unfolded = inplace_unfold(coeffs.view(-1, *current_layer_shape), \ + kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride) + num_constrs, out_h, out_w, in_c, H, W = coeffs_unfolded.shape + # make sure the small patch selected include all the nonzero coeffs + ####### NOTE: This check could be costly ####### + patch_mask_on_beta = (coeffs_unfolded.reshape(num_constrs, out_h, out_w, -1).abs().sum(-1) == \ + coeffs.reshape(num_constrs, -1).abs().sum(-1).reshape(num_constrs, 1, 1)) + # patch_mask_on_beta: (out_h, out_w, num_constrs) + patch_mask_on_beta = patch_mask_on_beta.permute(1, 2, 0) + # 2(lower, upper), out_c, out_h, out_w, batch, num_constrs + patch_mask_on_beta = patch_mask_on_beta.reshape(1, 1, out_h, out_w, 1, num_constrs) + self.general_beta[start_node.name].data = self.general_beta[start_node.name].data * patch_mask_on_beta + + def relu_cut(self, start_node, layer_name, last_lA, last_uA, current_layer_shape, unstable_idx=None, batch_mask=None): + # propagate relu neuron in cut constraints through relu layer + # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately + relu_coeffs = self.relu_coeffs[layer_name] + active_cuts = self.active_cuts[start_node.name] + # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes + if relu_coeffs is None or active_cuts.size(0) == 0: + # do nothing + return last_lA, last_uA + assert start_node.name in self.general_beta + # select current relu layer general beta + general_beta = self.select_active_general_beta(start_node, unstable_idx) + relu_coeffs = relu_coeffs[active_cuts] + if batch_mask is not None: + general_beta = general_beta[:, :, batch_mask] + last_lA, last_uA = self.general_beta_coeffs_addmm_to_A(last_lA, last_uA, general_beta, + relu_coeffs, current_layer_shape) + return last_lA, last_uA + + def pre_cut(self, start_node, layer_name, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None): + # propagate prerelu neuron in cut constraints through relu layer + # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately + pre_coeffs = self.pre_coeffs[layer_name] + active_cuts = self.active_cuts[start_node.name] + # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes + if pre_coeffs is None or active_cuts.size(0) == 0: + # do nothing + return lA, uA + general_beta = self.select_active_general_beta(start_node, unstable_idx) + pre_coeffs = pre_coeffs[active_cuts] + if batch_mask is not None: + general_beta = general_beta[:, :, batch_mask] + lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, pre_coeffs, current_layer_shape) + return lA, uA + + + @staticmethod + @torch.jit.script + def jit_arelu_lA(last_lA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d): + nu_hat_pos = 
last_lA.clamp(max=0.).abs()
+        tao = (-lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+        pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+        tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+        tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+        new_upper_d = pi / (pi + tao + 1e-10)
+        # need to customize the upper bound slope and lbias for (1) unstable relus and
+        # (2) relus that are used with upper boundary relaxation
+        # original upper bound slope is u/(u-l) also equal to pi/(pi+tao) if no beta_mm_coeffs[0]
+        # now the upper bound slope should be pi/(pi+tao) updated with beta_mm_coeffs[0]
+        unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(last_lA < 0)
+        # conv layer:
+        # upper_d: 1, batch, current_c, current_w, current_h
+        # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h
+        # dense layer:
+        # upper_d: 1, batch, current flattened nodes
+        # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes
+        new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+                      upper_d * (1. - unstable_upper_bound_index.float())
+        return nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index
+
+    @staticmethod
+    @torch.jit.script
+    def jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, lbias, pi, tao):
+        # if no unstable relus, the following bias should always be 0
+        if unstable_or_cut_index.sum() > 0:
+            # update lbias with new form, only contributed by unstable relus
+            uC = -upper.unsqueeze(0) * nu_hat_pos
+            lC = -lower.unsqueeze(0) * nu_hat_pos
+            # lbias: (spec unstable, batch, current flattened nodes) same as lA
+            lbias = (pi * lower.unsqueeze(0))
+
+            uC_mask = (beta_mm_coeffs[0] <= uC).to(lbias)
+            lC_mask = (beta_mm_coeffs[0] >= lC).to(lbias)
+            default_mask = ((1-uC_mask) * (1-lC_mask)).to(lbias)
+            lbias = - beta_mm_coeffs[0].to(lbias) * lC_mask + lbias * default_mask
+
+            # lbias[beta_mm_coeffs[0] <= uC] = 0.
+            # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)
+
+            # final lbias: (spec unstable, batch)
+            lbias = (lbias * unstable_or_cut_index.unsqueeze(0).float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)
+        return lbias
+
+    @staticmethod
+    @torch.jit.script
+    def jit_arelu_uA(last_uA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d):
+        nu_hat_pos = (-last_uA).clamp(max=0.).abs()
+        tao = (- lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+        pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)
+        tao, pi = tao.clamp(min=0.), pi.clamp(min=0.)
+        tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos)
+        new_upper_d = pi / (pi + tao + 1e-10)
+
+        # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos"
+
+        # unstable_or_cut_index = self.I.logical_or(self.arelu_coeffs.sum(0).view(self.I.shape) != 0)
+        unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(-last_uA < 0)
+        new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \
+                      upper_d * (1. 
- unstable_upper_bound_index.float()) + return nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index + + @staticmethod + @torch.jit.script + def jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, ubias, pi, tao): + if unstable_or_cut_index.sum() > 0: + uC = -upper.unsqueeze(0) * nu_hat_pos + lC = -lower.unsqueeze(0) * nu_hat_pos + ubias = -(pi * lower.unsqueeze(0)) + + uC_mask = (beta_mm_coeffs[1] <= uC).to(ubias) + lC_mask = (beta_mm_coeffs[1] >= lC).to(ubias) + default_mask = ((1-uC_mask) * (1-lC_mask)).to(ubias) + ubias = beta_mm_coeffs[1].to(ubias) * lC_mask + ubias * default_mask + + # ubias[beta_mm_coeffs[1] <= uC] = 0. + # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias) + + ubias = (ubias * unstable_or_cut_index.unsqueeze(0).float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1) + return ubias + + + def arelu_cut(self, start_node, layer_name, last_lA, last_uA, lower_d, upper_d, + lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, patch_size, + current_layer_shape, unstable_idx=None, batch_mask=None): + # propagate integer var of relu neuron (arelu) in cut constraints through relu layer + arelu_coeffs = self.arelu_coeffs[layer_name] + active_cuts = self.active_cuts[start_node.name] + # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes + if arelu_coeffs is None or active_cuts.size(0) == 0: + # do regular propagation without cut + uA, ubias = _bound_oneside(last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size) + lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, lower_b, upper_b, start_node, patch_size) + return lA, uA, lbias, ubias + + # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs) + general_beta = self.select_active_general_beta(start_node, unstable_idx) + # arelu_coeffs: (num_constrs, flattened current layer nodes) + arelu_coeffs = arelu_coeffs[active_cuts] + if batch_mask is not None: + general_beta = general_beta[:, :, batch_mask] + A = last_lA if last_lA is not None else last_uA + # beta_mm_coeffs[0] shape is the same as A + # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W) + # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes) + beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, arelu_coeffs, A, current_layer_shape) + # unstable_this_layer = torch.logical_and(x.lower < 0, x.upper > 0).unsqueeze(0) + # I is the unstable index in this relu layer: (batch, *layer shape) + # if there is node in cut constraint that is stable, also need to count its effect + # self.arelu_coeffs: (num_constrs, flattened current layer) + unstable_or_cut_index = I.logical_or(arelu_coeffs.sum(0).view(I[0:1].shape) != 0) + + if type(A) is Patches: + # patches mode, conv layer only + # x.lower (always regular shape): batch, current_c, current_h, current_w + # x_lower_unfold: unstable, batch, in_C, H, W (same as patches last_lA) + x_lower_unfold = _maybe_unfold(x.lower.unsqueeze(0), A) + x_upper_unfold = _maybe_unfold(x.upper.unsqueeze(0), A) + # first minus upper and lower and then unfold to patch size will save memory + x_upper_minus_lower_unfold = _maybe_unfold((x.upper - x.lower).unsqueeze(0), A) + ####### be careful with the unstable_this_layer and unstable_idx ####### + # unstable_this_layer is the unstable index in this layer + # unstable_idx is the unstable index in spec layer + # unstable_this_layer: spec 
unstable, batch, in_C, H, W (same as patches last_lA) + # unstable_this_layer = torch.logical_and(x_lower_unfold < 0, x_upper_unfold > 0) + # unstable_this_layer = _maybe_unfold(self.I.unsqueeze(0), A) + unstable_or_cut_index = _maybe_unfold(unstable_or_cut_index.unsqueeze(0), A) + if last_lA is not None: + assert beta_mm_coeffs[0].shape == last_lA.shape, f"{beta_mm_coeffs[0].shape} != {last_lA.shape}" + # last_lA.patches, nu_hat_pos, tao, pi: (unstable, batch, in_c, H, W) + nu_hat_pos = last_lA.patches.clamp(max=0.).abs() + tao = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10)) + pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10)) + tao, pi = tao.clamp(min=0.), pi.clamp(min=0.) + tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos) + new_upper_d = pi / (pi + tao + 1e-10) + + assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos" + + # unstable_upper_bound_index: spec unstable, batch, in_C, H, W (same as patches last_lA) + unstable_upper_bound_index = unstable_or_cut_index.logical_and(last_lA.patches < 0) + # upper_d: (spec unstable, 1, in_C, H, W) (unfolded shape, same as patches last_lA) + new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \ + upper_d * (1. - unstable_upper_bound_index.float()) + + if last_uA is None: uA, ubias = None, 0. + # lbias: unstable, batch + # lA: unstable, batch, in_C, H, W (same as patches last_lA) + lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size) + + # if general_beta[0].sum()!=0: import pdb; pdb.set_trace() + # there is any unstable relus in this layer + if unstable_or_cut_index.sum() > 0: + uC = -x_upper_unfold * nu_hat_pos + lC = -x_lower_unfold * nu_hat_pos + lbias = (pi * x_lower_unfold) + lbias[beta_mm_coeffs[0] <= uC] = 0. + lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias) + # lbias: unstable, batch, in_C, H, W (same as patches last_lA) => lbias: (unstable, batch) + lbias = (lbias * unstable_or_cut_index.float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1) + + if last_uA is not None: + # get the upper bound + nu_hat_pos = (-last_uA.patches).clamp(max=0.).abs() + tao = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10) + pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10) + tao, pi = tao.clamp(min=0.), pi.clamp(min=0.) + tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos) + new_upper_d = pi / (pi + tao + 1e-10) + + assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos" + + unstable_upper_bound_index = unstable_or_cut_index.logical_and((-last_uA.patches) < 0) + new_upper_d = new_upper_d * unstable_upper_bound_index.float() + \ + upper_d * (1. - unstable_upper_bound_index.float()) + + uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size) + if last_lA is None: lA, lbias = None, 0. + + if unstable_or_cut_index.sum() > 0: + uC = -x_upper_unfold * nu_hat_pos + lC = -x_lower_unfold * nu_hat_pos + ubias = -(pi * x_lower_unfold) + ubias[beta_mm_coeffs[1] <= uC] = 0. 
+ ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias) + # ubias: unstable, batch, in_C, H, W (same as patches last_uA) => ubias: (unstable, batch) + ubias = (ubias * unstable_or_cut_index.float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1) + else: + # dense + if last_lA is not None: + # ##################### + # # C is nu_hat_pos + # # last_lA: (spec unstable, batch, current flattened nodes (current_c*current_h*current_w)) + # nu_hat_pos = last_lA.clamp(max=0.).abs() + # # pi, tao: spec_unstable, batch, current layer shape (same as last_lA) + # tao = (-x.lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[0]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10) + # pi = (x.upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[0]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10) + # tao, pi = tao.clamp(min=0.), pi.clamp(min=0.) + # tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos) + # new_upper_d = pi / (pi + tao + 1e-10) + + # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos" + + # # need to customize the upper bound slope and lbias for (1) unstable relus and + # # (2) relus that are used with upper boundary relaxation + # # original upper bound slope is u/(u-l) also equal to pi/(pi+tao) if no beta_mm_coeffs[0] + # # now the upper bound slope should be pi/(p+tao) updated with beta_mm_coeffs[0] + # unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(last_lA < 0) + # # conv layer: + # # upper_d: 1, batch, current_c, current_w, current_h + # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h + # # dense layer: + # # upper_d: 1, batch, current flattened nodes + # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes + # new_upper_d = new_upper_d * unstable_upper_bound_index.float() +\ + # upper_d * (1. - unstable_upper_bound_index.float()) + # ##################### + + nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_lA(last_lA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d) + + if last_uA is None: uA, ubias = None, 0. + lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size) + + # if unstable_or_cut_index.sum() == 0: assert (lbias == 0).all(), "lbias should be 0 if no unstable relus" + + # ##################### + # # if no unstable, following bias should always be 0 + # if unstable_or_cut_index.sum() > 0: + # # update lbias with new form, only contribued by unstable relus + # uC = -x.upper.unsqueeze(0) * nu_hat_pos + # lC = -x.lower.unsqueeze(0) * nu_hat_pos + # # lbias: (spec unstable, batch, current flattened nodes) same as lA + # lbias = (pi * x.lower.unsqueeze(0)) + # lbias[beta_mm_coeffs[0] <= uC] = 0. 
+ # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias) + # # final lbias: (spec unstable, batch) + # lbias = (lbias * unstable_or_cut_index.unsqueeze(0).float()).view(lbias.shape[0], lbias.shape[1], -1).sum(-1) + # ##################### + lbias = self.jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, x.lower, x.upper, lbias, pi, tao) + + if last_uA is not None: + # # C is nu_hat_pos + # nu_hat_pos = (-last_uA).clamp(max=0.).abs() + # tao = (- x.lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[1]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10) + # pi = (x.upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[1]) / (x.upper.unsqueeze(0) - x.lower.unsqueeze(0) + 1e-10) + # tao, pi = tao.clamp(min=0.), pi.clamp(min=0.) + # tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos) + # new_upper_d = pi / (pi + tao + 1e-10) + + # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos" + + # # unstable_or_cut_index = self.I.logical_or(self.arelu_coeffs.sum(0).view(self.I.shape) != 0) + # unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(-last_uA < 0) + # new_upper_d = new_upper_d * unstable_upper_bound_index.float() +\ + # upper_d * (1. - unstable_upper_bound_index.float()) + nu_hat_pos, tao, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_uA(last_uA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d) + + # one can test uA by optimize -obj which should have the same obj value + uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size) + if last_lA is None: lA, lbias = None, 0. + + # if unstable_or_cut_index.sum() == 0: assert ubias == 0, "ubias should be 0 if no unstable relus" + + # if unstable_or_cut_index.sum() > 0: + # uC = -x.upper.unsqueeze(0) * nu_hat_pos + # lC = -x.lower.unsqueeze(0) * nu_hat_pos + # ubias = -(pi * x.lower.unsqueeze(0)) + # ubias[beta_mm_coeffs[1] <= uC] = 0. 
+ # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias) + # ubias = (ubias * unstable_or_cut_index.unsqueeze(0).float()).view(ubias.shape[0], ubias.shape[1], -1).sum(-1) + + ubias = self.jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, x.lower, x.upper, ubias, pi, tao) + + return lA, uA, lbias, ubias + + def input_cut(self, start_node, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None): + # propagate input neuron in cut constraints through relu layer + active_cuts = self.active_cuts[start_node.name] + if self.x_coeffs is None or active_cuts.size(0) == 0: + return lA, uA + + if type(lA) is Patches: + A = lA if lA is not None else uA + self.patch_trick(start_node, "input", A, current_layer_shape) + + general_beta = self.select_active_general_beta(start_node, unstable_idx) + x_coeffs = self.x_coeffs[active_cuts] + if batch_mask is not None: + general_beta = general_beta[:, :, batch_mask] + # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs) + # x_coeffs: (num_constrs, flattened input dims) + # beta_bias: (2(0 lower, 1 upper), batch, spec) + lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, x_coeffs, current_layer_shape) + return lA, uA + + def bias_cut(self, start_node, lb, ub, unstable_idx=None, batch_mask=None): + active_cuts = self.active_cuts[start_node.name] + if self.cut_bias is None or active_cuts.size(0) == 0: + return lb, ub + bias_coeffs = self.cut_bias[active_cuts] + general_beta = self.select_active_general_beta(start_node, unstable_idx) + if batch_mask is not None: + general_beta = general_beta[:, :, batch_mask] + # add bias for the bias term of general cut + # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs) + # bias_coeffs: (num_constrs,) + # beta_bias: (2(0 lower, 1 upper), batch, spec) + beta_bias = torch.einsum('sihj,j->shi', general_beta, bias_coeffs) + lb = lb + beta_bias[0] if lb is not None else None + ub = ub - beta_bias[1] if ub is not None else None + return lb, ub + + +# Choose upper or lower bounds based on the sign of last_A +# this is a copy from activation.py +def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg, start_node, patch_size): + if last_A is None: + return None, 0 + if type(last_A) == Tensor: + A, bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg, contiguous=True) + return A, bias + elif type(last_A) == Patches: + # if last_A is not an identity matrix + assert last_A.identity == 0 + if last_A.identity == 0: + # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension. + # or (unstable_size, batch_size, in_c, H, W) when it is sparse. + patches = last_A.patches + patches_shape = patches.shape + if len(patches_shape) == 6: + patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:]) + if d_pos is not None: + d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) + if d_neg is not None: + d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) + if b_pos is not None: + b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) + if b_neg is not None: + b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) + A_prod, bias = multiply_by_A_signs(patches, d_pos, d_neg, b_pos, b_neg) + # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse. + # For sparse patches the return bias size is (unstable_size, batch). + # For regular patches the return bias size is (spec, batch, out_h, out_w). 
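For reference, `multiply_by_A_signs` is imported from `.clampmult`; assuming it implements the usual CROWN sign split, an unoptimized matrix-mode equivalent looks like the following sketch (shapes and the helper name are illustrative):

```python
import torch

def multiply_by_A_signs_ref(last_A, d_pos, d_neg, b_pos, b_neg):
    """Reference (unoptimized) sign split: positive entries of last_A pick the
    d_pos/b_pos relaxation, negative entries pick d_neg/b_neg."""
    A_pos = last_A.clamp(min=0)
    A_neg = last_A.clamp(max=0)
    next_A = A_pos * d_pos + A_neg * d_neg
    # Bias terms are accumulated over the neuron dimension (here: the last dim).
    bias = 0
    if b_pos is not None:
        bias = bias + (A_pos * b_pos).sum(-1)
    if b_neg is not None:
        bias = bias + (A_neg * b_neg).sum(-1)
    return next_A, bias

# Shapes as in the matrix mode: (spec, batch, nodes); slopes broadcast over spec.
last_A = torch.randn(5, 3, 7)
d_pos, d_neg = torch.rand(1, 3, 7), torch.rand(1, 3, 7)
b_pos, b_neg = torch.randn(1, 3, 7), torch.randn(1, 3, 7)
A, bias = multiply_by_A_signs_ref(last_A, d_pos, d_neg, b_pos, b_neg)
assert A.shape == (5, 3, 7) and bias.shape == (5, 3)
```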
+            if len(patches_shape) == 6:
+                A_prod = A_prod.view(*patches_shape)
+            # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.
+            if start_node is not None:
+                if last_A.unstable_idx is not None:
+                    # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
+                    patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]
+                else:
+                    # Regular patches.
+                    patch_size[start_node.name] = A_prod.size()
+            return Patches(A_prod, last_A.stride, last_A.padding, A_prod.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), bias
+
+
+# In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return.
+# this is a copy from activation.py
+def _maybe_unfold(d_tensor, last_A):
+    # d_tensor (out_c, current_c, current_h, current_w): out_c shares the same alpha across the spec layer
+    if d_tensor is None:
+        return None
+    # if mode == "matrix" or d_tensor is None or last_A is None:
+    if type(last_A) is not Patches or d_tensor is None or last_A is None:
+        return d_tensor
+    # Inputs are slopes with shape (spec, batch, input_c, input_h, input_w).
+    # Here spec is the same as out_c.
+    # assert d_tensor.ndim == 5
+    origin_d_shape = d_tensor.shape
+    if d_tensor.ndim == 6:
+        d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:])
+    d_shape = d_tensor.size()
+    # Reshape to 4-D tensor to unfold.
+    d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:])
+    # Unfold the slope matrix as patches. Patch shape is (spec * batch, out_h, out_w, in_c, H, W).
+    d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding)
+    # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c.
+    d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])
+    if last_A.unstable_idx is not None:
+        if d_unfolded_r.size(0) == 1:
+            if len(last_A.unstable_idx) == 3:
+                # Broadcast the spec shape, so we only need to select the rest of the dimensions.
+                # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).
+                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]
+            elif len(last_A.unstable_idx) == 4:
+                # [spec, batch, output_h, output_w, input_c, H, W]
+                # to [output_h, output_w, batch, in_c, H, W]
+                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)
+                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]]
+            else:
+                raise NotImplementedError()
+            # output shape: (unstable_size, batch, in_c, H, W).
+        else:
+            d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]
+    # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).
+    # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W).
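The unfold layout that `_maybe_unfold` relies on can be reproduced with plain `torch.nn.functional.unfold`; a minimal sketch of the documented `(N, out_h, out_w, in_c, H, W)` layout, assuming default dilation and no inserted zeros:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 8, 8
kh, kw, stride, padding = 3, 3, 1, 1
x = torch.randn(N, C, H, W)

# F.unfold gives (N, C*kh*kw, L); reshape/permute into (N, out_h, out_w, C, kh, kw).
out_h = (H + 2 * padding - kh) // stride + 1
out_w = (W + 2 * padding - kw) // stride + 1
u = F.unfold(x, kernel_size=(kh, kw), stride=stride, padding=padding)
u = u.view(N, C, kh, kw, out_h, out_w).permute(0, 4, 5, 1, 2, 3)
assert u.shape == (N, out_h, out_w, C, kh, kw)

# Sanity check: the patch at output position (1, 1) equals the input window
# starting at row/col 0 (stride 1, padding 1).
assert torch.allclose(u[0, 1, 1], x[0, :, 0:3, 0:3])
```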
+    if d_unfolded_r.ndim != last_A.patches.ndim:
+        d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4)
+    return d_unfolded_r
diff --git a/auto_LiRPA/operators/dropout.py b/auto_LiRPA/operators/dropout.py
index 97c803e..d4601a6 100644
--- a/auto_LiRPA/operators/dropout.py
+++ b/auto_LiRPA/operators/dropout.py
@@ -1,39 +1,69 @@
 from .base import *

 class BoundDropout(Bound):
-    def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
-        super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
-        self.dropout = nn.Dropout(p=attr['ratio'])
-        self.scale = 1 / (1 - attr['ratio'])
+    def __init__(self, attr, inputs, output_index, options):
+        super().__init__(attr, inputs, output_index, options)
+        if 'ratio' in attr:
+            self.ratio = attr['ratio']
+            self.dynamic = False
+        else:
+            self.ratio = None
+            self.dynamic = True
+        self.clear()
+
+    def clear(self):
+        self.mask = None
+
+    def forward(self, *inputs):
+        x = inputs[0]
+        if not self.training:
+            return x
+        if self.dynamic:
+            # Inputs: data, ratio (optional), training_mode (optional)
+            # We assume ratio must exist in the inputs.
+            # We ignore training_mode, but will use self.training which can be
+            # changed after BoundedModule is built.
+            assert inputs[1].dtype == torch.float32
+            self.ratio = inputs[1]
+        if self.ratio >= 1:
+            raise ValueError('Ratio in dropout should be less than 1')
+        # Sample the mask on the same device as the input.
+        self.mask = torch.rand(x.shape, device=x.device) > self.ratio
+        return x * self.mask / (1 - self.ratio)

-    @Bound.save_io_shape
-    def forward(self, x):
-        res = self.dropout(x)
-        self.mask = res == 0
-        return res
+    def _check_forward(self):
+        """If in training mode, a forward pass should have been called."""
+        if self.training and self.mask is None:
+            raise RuntimeError('For a model with dropout in the training mode, '\
+                'a clean forward pass must be called before bound computation')

-    def bound_backward(self, last_lA, last_uA, x):
+    def bound_backward(self, last_lA, last_uA, *args):
+        empty_A = [(None, None)] * (len(args) - 1)
+        if not self.training:
+            return [(last_lA, last_uA), *empty_A], 0, 0
+        self._check_forward()
         def _bound_oneside(last_A):
             if last_A is None:
                 return None
-            return torch.where(self.mask.unsqueeze(0), torch.tensor(0).to(last_A), last_A * self.scale)
+            return last_A * self.mask / (1 - self.ratio)
         lA = _bound_oneside(last_lA)
         uA = _bound_oneside(last_uA)
-        return [(lA, uA)], 0, 0
+        return [(lA, uA), *empty_A], 0, 0

-    def bound_forward(self, dim_in, x):
-        assert (torch.min(self.mask) >= 0)
-        lw = x.lw * self.mask.unsqueeze(1)
-        lb = x.lb * self.mask
-        uw = x.uw * self.mask.unsqueeze(1)
-        ub = x.ub * self.mask
+    def bound_forward(self, dim_in, x, *args):
+        if not self.training:
+            return x
+        self._check_forward()
+        lw = x.lw * self.mask.unsqueeze(1) / (1 - self.ratio)
+        lb = x.lb * self.mask / (1 - self.ratio)
+        uw = x.uw * self.mask.unsqueeze(1) / (1 - self.ratio)
+        ub = x.ub * self.mask / (1 - self.ratio)
         return LinearBound(lw, lb, uw, ub)

     def interval_propagate(self, *v):
-        h_L, h_U = v[0]
         if not self.training:
-            return h_L, h_U
-        else:
-            lower = torch.where(self.mask, torch.tensor(0).to(h_L), h_L * self.scale)
-            upper = torch.where(self.mask, torch.tensor(0).to(h_U), h_U * self.scale)
-            return lower, upper
\ No newline at end of file
+            return v[0]
+        self._check_forward()
+        h_L, h_U = v[0]
+        lower = h_L * self.mask / (1 - self.ratio)
+        upper = h_U * self.mask / (1 - self.ratio)
+        return lower, upper
\ No newline at end of file
diff --git a/auto_LiRPA/operators/dtype.py b/auto_LiRPA/operators/dtype.py
index 1ed293e..d5dfa38 100644
--- a/auto_LiRPA/operators/dtype.py
+++ b/auto_LiRPA/operators/dtype.py
@@ -1,8 +1,9 @@
 from .base import *
+from ..utils import Patches

 class BoundCast(Bound):
-    def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
-        super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+    def __init__(self, attr, inputs, output_index, options):
+        super().__init__(attr, inputs, output_index, options)
         self.to = attr['to']
         self.data_types = [
             None, torch.float, torch.uint8, torch.int8,
@@ -14,17 +15,25 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
         assert self.type is not None
         self.use_default_ibp = True

-    @Bound.save_io_shape
     def forward(self, x):
         self.type_in = x.dtype
         return x.to(self.type)

     def bound_backward(self, last_lA, last_uA, x):
-        lA = last_lA.to(self.type_in) if last_lA is not None else None
-        uA = last_uA.to(self.type_in) if last_uA is not None else None
+        if type(last_lA) == Tensor or type(last_uA) == Tensor:
+            lA = last_lA.to(self.type_in) if last_lA is not None else None
+            uA = last_uA.to(self.type_in) if last_uA is not None else None
+        else:
+            lA = uA = None
+            if last_lA is not None:
+                lA = Patches(last_lA.patches.to(self.type_in), last_lA.stride, last_lA.padding, last_lA.shape, last_lA.identity, last_lA.unstable_idx, last_lA.output_shape)
+            if last_uA is not None:
+                uA = Patches(last_uA.patches.to(self.type_in), last_uA.stride, last_uA.padding, last_uA.shape, last_uA.identity, last_uA.unstable_idx, last_uA.output_shape)
         return [(lA, uA)], 0, 0

     def bound_forward(self, dim_in, x):
         return LinearBound(
             x.lw.to(self.type), x.lb.to(self.type),
             x.uw.to(self.type), x.ub.to(self.type))
+
+    def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+        self.solver_vars = self.forward(v[0])
diff --git a/auto_LiRPA/operators/leaf.py b/auto_LiRPA/operators/leaf.py
index 150e7f0..36c66b2 100644
--- a/auto_LiRPA/operators/leaf.py
+++ b/auto_LiRPA/operators/leaf.py
@@ -1,14 +1,17 @@
-""" Leaf nodes (independent nodes in the auto_LiRPA paper),
-including input, parameter, buffer, etc."""
+""" Leaf nodes (independent nodes in the auto_LiRPA paper).
+
+Including input, parameter, buffer, etc."""
 from .base import *

 class BoundInput(Bound):
-    def __init__(self, input_name, name, ori_name, value, perturbation=None):
-        super().__init__(input_name, name, ori_name)
+    def __init__(self, ori_name, value, perturbation=None, input_index=None):
+        super().__init__()
+        self.ori_name = ori_name
         self.value = value
         self.perturbation = perturbation
         self.from_input = True
+        self.input_index = input_index

     def __setattr__(self, key, value):
         super().__setattr__(key, value)
@@ -103,7 +106,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             prefix (str): the prefix for parameters and buffers used in this
                 module
         """
-        for name, param in self._parameters.items():
+        for param in self._parameters.values():
             if param is not None:
                 if len(prefix.split('.')) == 2:
                     destination[self.ori_name] = param if keep_vars else param.detach()
@@ -111,7 +114,7 @@
                     # change parameters' name to self.ori_name when calling state_dict()
                     destination[
                         '.'.join(prefix.split('.')[:-2]) + '.'
+ self.ori_name] = param if keep_vars else param.detach() - for name, buf in self._buffers.items(): + for buf in self._buffers.values(): if buf is not None: if len(prefix.split('.')) == 2: destination[self.ori_name] = buf if keep_vars else buf.detach() @@ -120,7 +123,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[ '.'.join(prefix.split('.')[:-2]) + '.' + self.ori_name] = buf if keep_vars else buf.detach() - @Bound.save_io_shape def forward(self): return self.value @@ -144,17 +146,15 @@ def infer_batch_dim(self, batch_size, *x): class BoundParams(BoundInput): - def __init__(self, input_name, name, ori_name, value, perturbation=None): - super().__init__(input_name, name, ori_name, None, perturbation) + def __init__(self, ori_name, value, perturbation=None): + super().__init__(ori_name, None, perturbation) self.register_parameter('param', value) self.from_input = False self.initializing = False - """Override register_parameter() hook to register only needed parameters.""" - def register_parameter(self, name, param): + """Override register_parameter() hook to register only needed parameters.""" if name == 'param': - # self._parameters[name] = param # cannot contain '.' in name, it will cause error when loading state_dict return super().register_parameter(name, param) else: # Just register it as a normal property of class. @@ -163,22 +163,20 @@ def register_parameter(self, name, param): def init(self, initializing=False): self.initializing = initializing - @Bound.save_io_shape def forward(self): if self.initializing: - return self.param_init + return self.param_init.requires_grad_(self.training) else: - return self.param + return self.param.requires_grad_(self.training) def infer_batch_dim(self, batch_size, *x): return -1 class BoundBuffers(BoundInput): - def __init__(self, input_name, name, ori_name, value, perturbation=None): - super().__init__(input_name, name, ori_name, None, perturbation) + def __init__(self, ori_name, value, perturbation=None): + super().__init__(ori_name, None, perturbation) self.register_buffer('buffer', value.clone().detach()) - @Bound.save_io_shape def forward(self): - return self.buffer \ No newline at end of file + return self.buffer diff --git a/auto_LiRPA/operators/linear.py b/auto_LiRPA/operators/linear.py index 7aa8ed9..8ce3a8b 100644 --- a/auto_LiRPA/operators/linear.py +++ b/auto_LiRPA/operators/linear.py @@ -1,9 +1,13 @@ """ Linear (possibly with weight perturbation) or Dot product layers """ from .base import * from .bivariate import BoundMul +from ..patches import Patches +from .solver_utils import grb +from torch import Tensor +from ..patches import inplace_unfold class BoundLinear(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): + def __init__(self, attr, inputs, output_index, options): # Gemm: # A = A if transA == 0 else A.T # B = B if transB == 0 else B.T @@ -11,7 +15,7 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio # Y = alpha * np.dot(A, B) + beta * C # return Y - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + super().__init__(attr, inputs, output_index, options) # Defaults in ONNX self.transA = 0 @@ -28,19 +32,18 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio """Handle tranpose and linear coefficients.""" def _preprocess(self, a, b, c=None): - if self.transA and isinstance(a, torch.Tensor): + if self.transA and isinstance(a, Tensor): a = 
a.transpose(-2,-1) if self.alpha != 1.0: a = self.alpha * a - if not self.transB and isinstance(b, torch.Tensor): + if not self.transB and isinstance(b, Tensor): # our code assumes B is transposed (common case), so we transpose B only when it is not transposed in gemm. - b = b.transpose(-2,-1) + b = b.transpose(-2, -1) if c is not None: if self.beta != 1.0: c = self.beta * c return a, b, c - @Bound.save_io_shape def forward(self, x, w, b=None): x, w, b = self._preprocess(x, w, b) self.input_shape = self.x_shape = x.shape @@ -61,8 +64,8 @@ def onehot_mult(self, weight, bias, C, batch_size): if C.index.ndim == 2: # Shape is [spec, batch] - index = C.index.transpose(0,1) - coeffs = C.coeffs.transpose(0,1) + index = C.index.transpose(0, 1) + coeffs = C.coeffs.transpose(0, 1) else: index = C.index coeffs = C.coeffs @@ -90,13 +93,12 @@ def onehot_mult(self, weight, bias, C, batch_size): new_bias = new_bias.transpose(0, 1) return new_weight, new_bias - def bound_backward(self, last_lA, last_uA, *x): assert len(x) == 2 or len(x) == 3 has_bias = len(x) == 3 # x[0]: input node, x[1]: weight, x[2]: bias - input_lb = [xi.lower if hasattr(xi, 'lower') else None for xi in x] - input_ub = [xi.upper if hasattr(xi, 'upper') else None for xi in x] + input_lb = [getattr(xi, 'lower', None) for xi in x] + input_ub = [getattr(xi, 'upper', None) for xi in x] # transpose and scale each term if necessary. input_lb = self._preprocess(*input_lb) input_ub = self._preprocess(*input_ub) @@ -106,31 +108,83 @@ def bound_backward(self, last_lA, last_uA, *x): # Case #1: No weight/bias perturbation, only perturbation on input. if not self.is_input_perturbed(1) and (not has_bias or not self.is_input_perturbed(2)): + weight = input_lb[1] + bias = input_lb[2] if has_bias else None # If last_lA and last_uA are indentity matrices. - if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): + if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): # FIXME (12/28): we should check last_lA and last_uA separately. Same applies to the weight perturbed, bias perturbed settings. # Use this layer's W as the next bound matrices. Duplicate the batch dimension. Other dimensions are kept 1. # Not perturbed, so we can use either lower or upper. - lA_x = uA_x = input_lb[1].unsqueeze(1).repeat([1, batch_size] + [1] * (input_lb[1].ndim - 1)) - # Bias will be directly added to output. + assert last_lA.shape == last_uA.shape + shape_others = prod(last_lA.shape[2:-1]) + A_identity = torch.eye(shape_others).to(weight).view(shape_others, 1, 1, shape_others, 1) + assert last_lA.shape[0] == weight.size(0) * shape_others + w = weight.view(1, weight.size(0), *[1] * (len(last_lA.shape) - 2), weight.size(1)) + w = w * A_identity + + # expand the batch_size dim + lA_x = uA_x = w.view(last_lA.shape[0], 1, *last_lA.shape[2:-1], weight.size(1)).expand(last_lA.shape[0], *last_lA.shape[1:-1], weight.size(1)) if has_bias: - lbias = ubias = input_lb[2].unsqueeze(1).repeat(1, batch_size) + lbias = ubias = bias.unsqueeze(1).repeat(1, batch_size) elif isinstance(last_lA, OneHotC) or isinstance(last_uA, OneHotC): # We need to select several rows from the weight matrix (its shape is output_size * input_size). 
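The matrix-mode branch of `bound_backward` simply composes the output specification with the affine layer. A quick, self-contained check of the identity it relies on, `last_A @ (W x + b) == (last_A @ W) x + last_A @ b` (all sizes below are illustrative):

```python
import torch

spec, batch, out_f, in_f = 4, 3, 6, 5
last_A = torch.randn(spec, batch, out_f)   # A matrix over this layer's outputs
W = torch.randn(out_f, in_f)               # layer weight (already in transposed form)
b = torch.randn(out_f)
x = torch.randn(batch, in_f)

next_A = last_A.matmul(W)                  # A matrix over this layer's inputs
sum_bias = last_A.matmul(b)                # accumulated bias term

lhs = torch.einsum('sbo,bo->sb', last_A, x.matmul(W.t()) + b)
rhs = torch.einsum('sbi,bi->sb', next_A, x) + sum_bias
assert torch.allclose(lhs, rhs, atol=1e-5)
```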
- lA_x, lbias = self.onehot_mult(input_lb[1], input_lb[2] if has_bias else None, last_lA, batch_size) + lA_x, lbias = self.onehot_mult(weight, bias, last_lA, batch_size) if last_lA is last_uA: uA_x = lA_x ubias = lbias else: - uA_x, ubias = self.onehot_mult(input_lb[1], input_lb[2] if has_bias else None, last_uA, batch_size) + uA_x, ubias = self.onehot_mult(weight, bias, last_uA, batch_size) else: def _bound_oneside(last_A): if last_A is None: return None, 0 - # Just multiply this layer's weight into bound matrices, and produce biases. - next_A = last_A.to(input_lb[1]).matmul(input_lb[1]) - sum_bias = (last_A.to(input_lb[2]).matmul(input_lb[2]) - if has_bias else 0.0) - return next_A, sum_bias + if isinstance(last_A, torch.Tensor): + # Matrix mode. + # Just multiply this layer's weight into bound matrices, and produce biases. + next_A = last_A.to(weight).matmul(weight) + sum_bias = (last_A.to(bias).matmul(bias) + if has_bias else 0.0) + elif isinstance(last_A, Patches): + # Patches mode. After propagating through this layer, it will become a matrix. + # Reshape the weight matrix as a conv image. + # Weight was in (linear_output_shape, linear_input_shape) + # Reshape it to (linear_input_shape, c, h, w) + reshaped_weight = weight.transpose(0,1).view(-1, *last_A.input_shape[1:]) + # After unfolding the shape is (linear_input_shape, output_h, output_w, in_c, patch_h, patch_w) + unfolded_weight = inplace_unfold( + reshaped_weight, + kernel_size=last_A.patches.shape[-2:], + stride=last_A.stride, padding=last_A.padding, + inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding) + if has_bias: + # Do the same for the bias. + reshaped_bias = bias.view(*last_A.input_shape[1:]).unsqueeze(0) + # After unfolding the bias shape is (1, output_h, output_w, in_c, patch_h, patch_w) + unfolded_bias = inplace_unfold(reshaped_bias, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) + if last_A.unstable_idx is not None: + # Reshape our weight to (output_h, output_w, 1, in_c, patch_h, patch_w, linear_input_shape), 1 is the inserted batch dim. + unfolded_weight_r = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(2) + # for sparse patches the shape is (unstable_size, batch, in_c, patch_h, patch_w). Batch size is 1 so no need to select here. + # We select in the (output_h, out_w) dimension. + selected_weight = unfolded_weight_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + next_A = torch.einsum('sbchw,sbchwi->sbi', last_A.patches, selected_weight) + if has_bias: + # Reshape our bias to (output_h, output_w, 1, in_c, patch_h, patch_w). We already have the batch dim. + unfolded_bias_r = unfolded_bias.permute(1, 2, 0, 3, 4, 5) + selected_bias = unfolded_bias_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + sum_bias = torch.einsum('sbchw,sbchw->sb', last_A.patches, selected_bias) + else: + # Reshape our weight to (1, 1, output_h, output_w, in_c, patch_h, patch_w, linear_input_shape), 1 is the spec and batch. + selected_weight = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(0).unsqueeze(0) + next_A_r = torch.einsum('sbpqchw,sbpqchwi->spqbi', last_A.patches, selected_weight) + # We return a matrix with flattened spec dimension (corresponding to out_c * out_h * out_w). 
+ next_A = next_A_r.reshape(-1, next_A_r.size(-2), next_A_r.size(-1)) + if has_bias: + # Reshape our bias to (1, 1, output_h, output_w, in_c, patch_h, patch_w) + selected_bias = unfolded_bias.unsqueeze(0) + sum_bias_r = torch.einsum('sbpqchw,sbpqchw->spqb', last_A.patches, selected_bias) + sum_bias = sum_bias_r.reshape(-1, sum_bias_r.size(-1)) + return next_A, sum_bias if has_bias else 0.0 lA_x, lbias = _bound_oneside(last_lA) uA_x, ubias = _bound_oneside(last_uA) @@ -176,30 +230,32 @@ def _bound_oneside(last_A): def _reshape(self, x_l, x_u, y_l, y_u): x_shape, y_shape = self.input_shape, self.y_shape - # (x_1, x_2, ..., x_{n-1}, y_2, x_n) + # (x_1, x_2, ..., x_{n-1}, -1, x_n) # FIXME x_l = x_l.unsqueeze(-2) x_u = x_u.unsqueeze(-2) + # FIXME merge these two cases if len(x_shape) == len(y_shape): - # (x_1, x_2, ..., x_{n-1}, y_n, y_{n-1}) - shape = x_shape[:-1] + (y_shape[-1], y_shape[-2]) + # (x_1, x_2, ..., -1, y_n, y_{n-1}) y_l = y_l.unsqueeze(-3) y_u = y_u.unsqueeze(-3) elif len(y_shape) == 2: - # (x_1, x_2, ..., x_{n-1}, y_2, y_1) - shape = x_shape[:-1] + y_shape[1:] + y_shape[:1] + # (x_1, x_2, ..., -1, y_2, y_1) y_l = y_l.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3) y_u = y_u.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3) + else: + raise ValueError(f'Unsupported shapes: x_shape {x_shape}, y_shape {y_shape}') + return x_l, x_u, y_l, y_u def _relax(self, input_lb, input_ub): return BoundMul.get_bound_mul(*self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1])) + # FIXME This is nonlinear. Move to `bivariate.py`. def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y): # Note: x and y are not tranposed or scaled, and we should avoid using them directly. # Use input_lb and input_ub instead. alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = self._relax(input_lb, input_ub) - alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0) beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0) x_shape, y_shape = input_lb[0].size(), input_lb[1].size() @@ -216,34 +272,33 @@ def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y) def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg, gamma_neg): if last_A is None: return None, None, 0 - if isinstance(last_A, eyeC): - A_x = alpha_pos.squeeze(0).permute(1, 0, 2).repeat(1, last_A.shape[1], 1) - A_y = beta_pos * torch.eye(last_A.shape[2], device=last_A.device) \ - .view((last_A.shape[2], 1, last_A.shape[2], 1)) - if len(dim_y) != 0: - A_y = torch.sum(beta_pos, dim=dim_y) - bias = gamma_pos.transpose(0, 1) - else: - # last_uA has size (batch, spec, output) - last_A_pos = last_A.clamp(min=0).unsqueeze(-1) - last_A_neg = last_A.clamp(max=0).unsqueeze(-1) - # alpha_u has size (batch, spec, output, input) - # uA_x has size (batch, spec, input). - A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + \ - alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1) - # beta_u has size (batch, spec, output, input) - # uA_y is for weight matrix, with parameter size (output, input) - # uA_y has size (batch, spec, output, input). This is an element-wise multiplication. 
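`BoundMul.get_bound_mul` (used by `_relax` above) returns planes `alpha * x + beta * y + gamma` that bound the product `x * y` over a box. The guarantee being relied on is the standard McCormick relaxation; a numeric sketch (the exact parameterization inside `get_bound_mul` may differ):

```python
import torch

xl, xu, yl, yu = -1.0, 2.0, -0.5, 1.5
xs = torch.linspace(xl, xu, 33).view(-1, 1)
ys = torch.linspace(yl, yu, 33).view(1, -1)
prod = xs * ys

# Standard McCormick planes for x*y on [xl, xu] x [yl, yu]:
lower1 = yl * xs + xl * ys - xl * yl   # from (x - xl)(y - yl) >= 0
lower2 = yu * xs + xu * ys - xu * yu   # from (x - xu)(y - yu) >= 0
upper1 = yu * xs + xl * ys - xl * yu   # from (x - xl)(y - yu) <= 0
upper2 = yl * xs + xu * ys - xu * yl   # from (x - xu)(y - yl) <= 0

assert (torch.max(lower1, lower2) <= prod + 1e-6).all()
assert (torch.min(upper1, upper2) >= prod - 1e-6).all()
```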
- A_y = last_A_pos * beta_pos + last_A_neg * beta_neg - if len(dim_y) != 0: - A_y = torch.sum(A_y, dim=dim_y) - # last_uA has size (batch, spec, output) - _last_A_pos = last_A_pos.reshape(last_A.shape[0], last_A.shape[1], -1) - _last_A_neg = last_A_neg.reshape(last_A.shape[0], last_A.shape[1], -1) - # gamma_u has size (batch, output, 1) - # ubias has size (batch, spec, 1) - bias = _last_A_pos.transpose(0, 1).matmul(gamma_pos).transpose(0, 1) + \ - _last_A_neg.transpose(0, 1).matmul(gamma_neg).transpose(0, 1) + if isinstance(last_A, eyeC): # FIXME (12/28): Handle the OneHotC case. + #FIXME previous implementation is incorrect + # expanding eyeC for now + last_A = (torch.eye(last_A.shape[0], device=last_A.device) + .view(last_A.shape[0], 1, *last_A.shape[2:]).expand(last_A.shape)) + + # last_uA has size (batch, spec, output) + last_A_pos = last_A.clamp(min=0).unsqueeze(-1) + last_A_neg = last_A.clamp(max=0).unsqueeze(-1) + # alpha_u has size (batch, spec, output, input) + # uA_x has size (batch, spec, input). + A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + \ + alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1) + # beta_u has size (batch, spec, output, input) + # uA_y is for weight matrix, with parameter size (output, input) + # uA_y has size (batch, spec, output, input). This is an element-wise multiplication. + A_y = last_A_pos * beta_pos + last_A_neg * beta_neg + if len(dim_y) != 0: + A_y = torch.sum(A_y, dim=dim_y) + # last_uA has size (batch, spec, output) + _last_A_pos = last_A_pos.reshape(last_A.shape[0], last_A.shape[1], -1) + _last_A_neg = last_A_neg.reshape(last_A.shape[0], last_A.shape[1], -1) + # gamma_u has size (batch, output, 1) + # ubias has size (batch, spec, 1) + bias = _last_A_pos.transpose(0, 1).matmul(gamma_pos).transpose(0, 1) + \ + _last_A_neg.transpose(0, 1).matmul(gamma_neg).transpose(0, 1) + bias = bias.squeeze(-1) return A_x, A_y, bias @@ -254,24 +309,6 @@ def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg, @staticmethod def _propagate_Linf(x, w): - if Interval.use_relative_bounds(x): - if len(x.nominal.shape) == 2 and w.ndim == 3: - nominal = torch.bmm(x.nominal.unsqueeze(1), w.transpose(-1, -2)).squeeze(1) - lower_offset = ( - torch.bmm(x.lower_offset.unsqueeze(1), w.clamp(min=0).transpose(-1, -2)) + - torch.bmm(x.upper_offset.unsqueeze(1), w.clamp(max=0).transpose(-1, -2))).squeeze(1) - upper_offset = ( - torch.bmm(x.lower_offset.unsqueeze(1), w.clamp(max=0).transpose(-1, -2)) + - torch.bmm(x.upper_offset.unsqueeze(1), w.clamp(min=0).transpose(-1, -2))).squeeze(1) - else: - nominal = x.nominal.matmul(w.transpose(-1, -2)) - lower_offset = ( - x.lower_offset.matmul(w.clamp(min=0).transpose(-1, -2)) + - x.upper_offset.matmul(w.clamp(max=0).transpose(-1, -2))) - upper_offset = ( - x.lower_offset.matmul(w.clamp(max=0).transpose(-1, -2)) + - x.upper_offset.matmul(w.clamp(min=0).transpose(-1, -2))) - return Interval(None, None, nominal, lower_offset, upper_offset) h_L, h_U = x mid = (h_L + h_U) / 2 diff = (h_U - h_L) / 2 @@ -288,18 +325,11 @@ def interval_propagate(self, *v, C=None, w=None): has_bias = self is not None and len(v) == 3 if self is not None: # This will convert an Interval object to tuple. We need to add perturbation property later. 
- if Interval.use_relative_bounds(v[0]): - v_nominal = self._preprocess(v[0].nominal, v[1].nominal, v[2].nominal) - v_lower_offset = self._preprocess(v[0].lower_offset, v[1].lower_offset, v[2].lower_offset) - v_upper_offset = self._preprocess(v[0].upper_offset, v[1].upper_offset, v[2].upper_offset) - v = [Interval(None, None, bounds[0], bounds[1], bounds[2]) - for bounds in zip(v_nominal, v_lower_offset, v_upper_offset)] - else: - v_lb, v_ub = zip(*v) - v_lb = self._preprocess(*v_lb) - v_ub = self._preprocess(*v_ub) - # After preprocess the lower and upper bounds, we make them Intervals again. - v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) for bounds in zip(v_lb, v_ub, v)] + v_lb, v_ub = zip(*v) + v_lb = self._preprocess(*v_lb) + v_ub = self._preprocess(*v_ub) + # After preprocess the lower and upper bounds, we make them Intervals again. + v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) for bounds in zip(v_lb, v_ub, v)] if w is None and self is None: # Use C as the weight, no bias. w, lb, ub = C, torch.tensor(0., device=C.device), torch.tensor(0., device=C.device) @@ -311,25 +341,16 @@ def interval_propagate(self, *v, C=None, w=None): # C matrix merging not supported. assert C is None res = self.interval_propagate_with_weight(*v) - if Interval.use_relative_bounds(res): - if has_bias: - raise NotImplementedError - else: - return res + l, u = res + if has_bias: + return l + v[2][0], u + v[2][1] else: - l, u = res - if has_bias: - return l + v[2][0], u + v[2][1] - else: - return l, u + return l, u else: - # Use weight - if Interval.use_relative_bounds(v[1]): - w = v[1].nominal - else: - w = v[1][0] + # Use weight + w = v[1][0] if has_bias: - lb, ub = (v[2].lower, v[2].upper) if Interval.use_relative_bounds(v[2]) else v[2] + lb, ub = v[2] else: lb = ub = 0.0 @@ -342,14 +363,7 @@ def interval_propagate(self, *v, C=None, w=None): norm, eps = Interval.get_perturbation(v[0])[:2] if norm == np.inf: interval = BoundLinear._propagate_Linf(v[0], w) - if isinstance(interval, Interval): - b_center = (lb + ub) / 2 - interval.nominal += b_center - interval.lower_offset += lb - b_center - interval.upper_offset += ub - b_center - return interval - else: - center, deviation = interval + center, deviation = interval elif norm > 0: # General Lp norm. 
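`_propagate_Linf` above uses the center/deviation trick: propagate the box midpoint through `W` and the radius through `|W|`. A brute-force corner check on a tiny, self-contained example:

```python
import torch
from itertools import product

h_L = torch.tensor([-1.0, 0.5])
h_U = torch.tensor([0.0, 2.0])
W = torch.randn(3, 2)

mid = (h_L + h_U) / 2
diff = (h_U - h_L) / 2
center = W.matmul(mid)
deviation = W.abs().matmul(diff)
lb, ub = center - deviation, center + deviation

# Under an l_inf box, extreme values of W x are attained at box corners.
corners = torch.stack([torch.where(torch.tensor(m, dtype=torch.bool), h_U, h_L)
                       for m in product([0, 1], repeat=2)])
outs = corners.matmul(W.t())
assert (outs.min(0).values >= lb - 1e-6).all()
assert (outs.max(0).values <= ub + 1e-6).all()
```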
norm, eps = Interval.get_perturbation(v[0]) @@ -384,36 +398,7 @@ def interval_propagate(self, *v, C=None, w=None): def interval_propagate_with_weight(self, *v): input_norm, input_eps = Interval.get_perturbation(v[0]) - weight_norm, weight_eps = Interval.get_perturbation(v[1]) - - if Interval.use_relative_bounds(*v): - assert input_norm == weight_norm == np.inf - assert self.opt_matmul == 'economic' - - x, y = v[0], v[1] - - nominal = x.nominal.matmul(y.nominal.transpose(-1, -2)) - - matmul_offset = torch.matmul( - torch.max(x.lower_offset.abs(), x.upper_offset.abs()), - torch.max(y.upper_offset.abs(), y.lower_offset.abs()).transpose(-1, -2)) - - lower_offset = ( - x.nominal.clamp(min=0).matmul(y.lower_offset.transpose(-1, -2)) + - x.nominal.clamp(max=0).matmul(y.upper_offset.transpose(-1, -2)) + - x.lower_offset.matmul(y.nominal.clamp(min=0).transpose(-1, -2)) + - x.upper_offset.matmul(y.nominal.clamp(max=0).transpose(-1, -2)) - matmul_offset) - - upper_offset = ( - x.nominal.clamp(min=0).matmul(y.upper_offset.transpose(-1, -2)) + - x.nominal.clamp(max=0).matmul(y.lower_offset.transpose(-1, -2)) + - x.upper_offset.matmul(y.nominal.clamp(min=0).transpose(-1, -2)) + - x.lower_offset.matmul(y.nominal.clamp(max=0).transpose(-1, -2)) + matmul_offset) - - return Interval(None, None, nominal, lower_offset, upper_offset) - - self.x_shape = v[0][0].shape - self.y_shape = v[1][0].shape + weight_norm, weight_eps = Interval.get_perturbation(v[1]) if input_norm == np.inf and weight_norm == np.inf: # A memory-efficient implementation without expanding all the elementary multiplications @@ -424,9 +409,9 @@ def interval_propagate_with_weight(self, *v): dx, dy = F.relu(x_u - x_l), F.relu(y_u - y_l) base = x_l.matmul(y_l) - mask_xp, mask_xn = (x_l > 0).float(), (x_u < 0).float() + mask_xp, mask_xn = (x_l > 0).to(x_l.dtype), (x_u < 0).to(x_u.dtype) mask_xpn = 1 - mask_xp - mask_xn - mask_yp, mask_yn = (y_l > 0).float(), (y_u < 0).float() + mask_yp, mask_yn = (y_l > 0).to(y_l.dtype), (y_u < 0).to(y_u.dtype) mask_ypn = 1 - mask_yp - mask_yn lower, upper = base.clone(), base.clone() @@ -445,7 +430,7 @@ def interval_propagate_with_weight(self, *v): x_l, x_u = v[0][0].unsqueeze(-2), v[0][1].unsqueeze(-2) y_l, y_u = v[1][0].unsqueeze(-3), v[1][1].unsqueeze(-3) # Reuse the multiplication bounds and sum over results. 
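When both operands are perturbed, each elementwise product is bounded by the min/max over the four corner products and then summed over the contraction dimension. A sketch of this idea (the `opt_matmul == 'economic'` path above computes the same bounds without materializing all corners):

```python
import torch

x_l = torch.randn(2, 4)
x_u = x_l + torch.rand(2, 4)
y_l = torch.randn(4, 3)
y_u = y_l + torch.rand(4, 3)

# Four corner products of the broadcasted elementwise multiplications:
xl, xu = x_l.unsqueeze(-1), x_u.unsqueeze(-1)   # (2, 4, 1)
yl, yu = y_l.unsqueeze(0), y_u.unsqueeze(0)     # (1, 4, 3)
cands = torch.stack([xl * yl, xl * yu, xu * yl, xu * yu])
lower = cands.min(dim=0).values.sum(dim=1)      # (2, 3), summed over the contraction dim
upper = cands.max(dim=0).values.sum(dim=1)

# Monte-Carlo sanity check: every realizable product stays inside the bounds.
for _ in range(100):
    x = x_l + torch.rand_like(x_l) * (x_u - x_l)
    y = y_l + torch.rand_like(y_l) * (y_u - y_l)
    out = x.matmul(y)
    assert (out >= lower - 1e-5).all() and (out <= upper + 1e-5).all()
```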
- lower, upper = BoundMul.interval_propagate(*[(x_l, x_u), (y_l, y_u)]) + lower, upper = BoundMul.interval_propagate_both_perturbed(*[(x_l, x_u), (y_l, y_u)]) lower, upper = torch.sum(lower, -1), torch.sum(upper, -1) return lower, upper @@ -465,9 +450,46 @@ def interval_propagate_with_weight(self, *v): raise NotImplementedError( "Unsupported perturbation combination: data={}, weight={}".format(input_norm, weight_norm)) + @staticmethod + @torch.jit.script + def bound_forward_mul(x_lw: Tensor, x_lb: Tensor, x_uw: Tensor, x_ub: Tensor, w: Tensor): + w_pos, w_neg = w.clamp(min=0), w.clamp(max=0) + lw = x_lw.matmul(w_pos) + x_uw.matmul(w_neg) + uw = x_uw.matmul(w_pos) + x_lw.matmul(w_neg) + lb = x_lb.matmul(w_pos) + x_ub.matmul(w_neg) + ub = x_ub.matmul(w_pos) + x_lb.matmul(w_neg) + return lw, lb, uw, ub + + # w: an optional argument which can be utilized by BoundMatMul + def bound_dynamic_forward(self, x, w=None, b=None, C=None, max_dim=None, offset=0): + assert not self.transA and self.alpha == 1.0 and self.transB and self.beta == 1.0 + assert not self.is_input_perturbed(1) + assert not self.is_input_perturbed(2) + + weight = w.lb + bias = b.lb if b is not None else None + if C is not None: + weight = C.to(weight).matmul(weight).transpose(-1, -2) + if bias is not None: + bias = C.to(bias).matmul(bias) + lb = x.lb.unsqueeze(1) + else: + weight = weight.t() + lb = x.lb + + w_new = x.lw.matmul(weight) + b_new = lb.matmul(weight) + if C is not None: + b_new = b_new.squeeze(1) + if bias is not None: + b_new += bias + + return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim) + # w: an optional argument which can be utilized by BoundMatMul def bound_forward(self, dim_in, x, w=None, b=None, C=None): has_bias = b is not None + #FIXME _preprocess can only be applied to tensors so far but not linear bounds. x, w, b = self._preprocess(x, w, b) # Case #1: No weight/bias perturbation, only perturbation on input. 
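`bound_forward_mul` above splits the (fixed) weight by sign: lower coefficients take `W+` on the lower input bounds and `W-` on the upper ones, and vice versa for the upper side. A numeric check that the resulting linear functions bracket the layer output (all sizes and magnitudes below are illustrative):

```python
import torch

dim_in, batch, in_f, out_f = 3, 2, 4, 5
# Input linear bounds w.r.t. dim_in perturbation variables z in [-1, 1]:
x_lw = torch.randn(batch, dim_in, in_f) * 0.1
x_uw = x_lw + torch.rand(batch, dim_in, in_f) * 0.01
x_lb = torch.randn(batch, in_f)
x_ub = x_lb + 1.0 + torch.rand(batch, in_f)   # keep the gap strictly positive
w = torch.randn(in_f, out_f)                  # weight, already transposed as in the code

w_pos, w_neg = w.clamp(min=0), w.clamp(max=0)
lw = x_lw.matmul(w_pos) + x_uw.matmul(w_neg)
uw = x_uw.matmul(w_pos) + x_lw.matmul(w_neg)
lb = x_lb.matmul(w_pos) + x_ub.matmul(w_neg)
ub = x_ub.matmul(w_pos) + x_lb.matmul(w_neg)

z = torch.empty(batch, dim_in).uniform_(-1, 1)
# Pick any x between the input's lower and upper linear functions.
x_lo = torch.einsum('bdi,bd->bi', x_lw, z) + x_lb
x_hi = torch.einsum('bdi,bd->bi', x_uw, z) + x_ub
x = (x_lo + x_hi) / 2
out = x.matmul(w)
out_lo = torch.einsum('bdo,bd->bo', lw, z) + lb
out_hi = torch.einsum('bdo,bd->bo', uw, z) + ub
assert (out >= out_lo - 1e-5).all() and (out <= out_hi + 1e-5).all()
```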
@@ -480,16 +502,15 @@ def bound_forward(self, dim_in, x, w=None, b=None, C=None): w = C.to(w).matmul(w).transpose(-1, -2) if b is not None: b = C.to(b).matmul(b) - w_pos, w_neg = w.clamp(min=0), w.clamp(max=0) - lb = (x.lb.unsqueeze(1).matmul(w_pos) + x.ub.unsqueeze(1).matmul(w_neg)).squeeze(1) - ub = (x.ub.unsqueeze(1).matmul(w_pos) + x.lb.unsqueeze(1).matmul(w_neg)).squeeze(1) - else: + x_lb, x_ub = x.lb.unsqueeze(1), x.ub.unsqueeze(1) + else: w = w.t() - w_pos, w_neg = w.clamp(min=0), w.clamp(max=0) - lb = x.lb.matmul(w_pos) + x.ub.matmul(w_neg) - ub = x.ub.matmul(w_pos) + x.lb.matmul(w_neg) - lw = x.lw.matmul(w_pos) + x.uw.matmul(w_neg) - uw = x.uw.matmul(w_pos) + x.lw.matmul(w_neg) + x_lb, x_ub = x.lb, x.ub + lw, lb, uw, ub = BoundLinear.bound_forward_mul(x.lw, x_lb, x.uw, x_ub, w) + + if C is not None: + lb, ub = lb.squeeze(1), ub.squeeze(1) + if b is not None: lb += b ub += b @@ -524,7 +545,7 @@ def bound_forward_with_weight(self, dim_in, x, y): y.lower.unsqueeze(-3), y.upper.unsqueeze(-3), ) - res_mul = BoundMul.bound_forward(dim_in, x_unsqueeze, y_unsqueeze) + res_mul = BoundMul.bound_forward_both_perturbed(dim_in, x_unsqueeze, y_unsqueeze) return LinearBound( res_mul.lw.sum(dim=-1) if res_mul.lw is not None else None, res_mul.lb.sum(dim=-1), @@ -532,16 +553,76 @@ def bound_forward_with_weight(self, dim_in, x, y): res_mul.ub.sum(dim=-1) ) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + has_bias = self is not None and len(v) == 3 + # e.g., last layer gurobi vars (1024,) + gvars_array = np.array(v[0]) + # pre_layer_shape (1024,) + pre_layer_shape = gvars_array.shape + # this layer shape (100,) + # if last layer, this layer shape (9,) instead of (10,)!!! + this_layer_shape = self.lower.squeeze(0).shape + out_lbs = self.lower.squeeze(0).detach().cpu().numpy() if self.lower is not None else None + out_ubs = self.upper.squeeze(0).detach().cpu().numpy() if self.upper is not None else None + + # current layer weight (100, 1024) + this_layer_weight = v[1] + #### make sure if this is correct for per-label operations + if C is not None: + # merge specification C into last layer weights + # only last layer has C not None + this_layer_weight = C.squeeze(0).mm(this_layer_weight) + # if last layer, this layer weight (9,100) instead of (10,100)!!! + this_layer_weight = this_layer_weight.detach().cpu().numpy() + + this_layer_bias = None + if has_bias: + # current layer bias (100,) + this_layer_bias = v[2] + if C is not None: + this_layer_bias = C.squeeze(0).mm(this_layer_bias.unsqueeze(-1)).view(-1) + # if last layer, this layer bias (9,) instead of (10,)!!! + this_layer_bias = this_layer_bias.detach().cpu().numpy() + + new_layer_gurobi_vars = [] + + for neuron_idx in range(this_layer_shape[0]): + out_lb = out_lbs[neuron_idx] if out_lbs is not None else -float('inf') + out_ub = out_ubs[neuron_idx] if out_ubs is not None else float('inf') + + lin_expr = 0 + if has_bias: + lin_expr = this_layer_bias[neuron_idx].item() + coeffs = this_layer_weight[neuron_idx, :] + + if solver_pkg == 'gurobi': + lin_expr += grb.LinExpr(coeffs, v[0]) + else: + # FIXME (01/12/22): This is slow, must be fixed using addRow() or similar. 
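The `gurobi` branch above builds the whole expression with one `grb.LinExpr(coeffs, vars)` call, which is what the FIXME suggests adopting for the generic branch too. A minimal usage sketch (assuming `gurobipy` with a working license; variable names are illustrative):

```python
import gurobipy as grb

model = grb.Model()
xs = [model.addVar(lb=-1, ub=1, name=f'x{i}') for i in range(3)]
model.update()
coeffs = [0.5, -2.0, 1.5]
# One constructor call instead of a Python loop over coefficients.
expr = grb.LinExpr(coeffs, xs) + 0.25
y = model.addVar(lb=-float('inf'), ub=float('inf'), name='y')
model.addConstr(expr == y, name='y_eq')
model.update()
```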
+ for i in range(len(coeffs)): + try: + lin_expr += coeffs[i] * v[0][i] + except TypeError: + lin_expr += coeffs[i] * v[0][i].var + + var = model.addVar(lb=out_lb, ub=out_ub, obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq') + new_layer_gurobi_vars.append(var) + + self.solver_vars = new_layer_gurobi_vars + model.update() + class BoundMatMul(BoundLinear): # Reuse most functions from BoundLinear. - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.transA = 0 - self.transB = 1 # MatMul assumes B is transposed. - self.nonlinear = True + self.transB = 0 + self.requires_input_bounds = [0, 1] - @Bound.save_io_shape def forward(self, x, y): self.x_shape = x.shape self.y_shape = y.shape @@ -550,23 +631,14 @@ def forward(self, x, y): return x.matmul(y) def interval_propagate(self, *v): - w_l = v[1][0].transpose(-1, -2) - w_u = v[1][1].transpose(-1, -2) - lower, upper = super().interval_propagate(v[0], (w_l, w_u)) - return lower, upper + lower, upper = super().interval_propagate(*v) + return lower, upper def bound_backward(self, last_lA, last_uA, *x): assert len(x) == 2 - # BoundLinear has W transposed. - x[1].lower = x[1].lower.transpose(-1, -2) - x[1].upper = x[1].upper.transpose(-1, -2) results = super().bound_backward(last_lA, last_uA, *x) - # Transpose input back. - x[1].lower = x[1].lower.transpose(-1, -2) - x[1].upper = x[1].upper.transpose(-1, -2) lA_y = results[0][1][0].transpose(-1, -2) if results[0][1][0] is not None else None uA_y = results[0][1][1].transpose(-1, -2) if results[0][1][1] is not None else None - # Transpose result on A. 
return [results[0][0], (lA_y, uA_y), results[0][2]], results[1], results[2] def bound_forward(self, dim_in, x, y): @@ -579,37 +651,62 @@ def bound_forward(self, dim_in, x, y): y.upper.transpose(-1, -2) if y.upper is not None else None )) - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x) - class BoundNeg(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) - @Bound.save_io_shape def forward(self, x): return -x def bound_backward(self, last_lA, last_uA, x): - return [(-last_lA if last_lA is not None else None, + if type(last_lA) == Tensor or type(last_uA) == Tensor: + return [(-last_lA if last_lA is not None else None, -last_uA if last_uA is not None else None)], 0, 0 + elif type(last_lA) == Patches or type(last_uA) == Patches: + if last_lA is not None: + lA = Patches(-last_lA.patches, last_lA.stride, last_lA.padding, last_lA.shape, unstable_idx=last_lA.unstable_idx, output_shape=last_lA.output_shape) + else: + lA = None + + if last_uA is not None: + uA = Patches(-last_uA.patches, last_uA.stride, last_uA.padding, last_uA.shape, unstable_idx=last_uA.unstable_idx, output_shape=last_uA.output_shape) + else: + uA = None + return [(lA, uA)], 0, 0 + else: + raise NotImplementedError def bound_forward(self, dim_in, x): return LinearBound(-x.uw, -x.ub, -x.lw, -x.lb) def interval_propagate(self, *v): - return -v[0][1], -v[0][0] + return -v[0][1], -v[0][0] + class BoundCumSum(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x, axis): self.axis = axis return torch.cumsum(x, axis) def infer_batch_dim(self, batch_size, *x): assert self.axis != x[0] - return x[0] \ No newline at end of file + return x[0] + + +class BoundIdentity(Bound): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.use_default_ibp = True + + def forward(self, x): + return x + + def bound_backward(self, last_lA, last_uA, x): + return [(last_lA, last_uA)], 0, 0 + + def bound_forward(self, dim_in, x): + return x diff --git a/auto_LiRPA/operators/logical.py b/auto_LiRPA/operators/logical.py index 869a50a..179a7ac 100644 --- a/auto_LiRPA/operators/logical.py +++ b/auto_LiRPA/operators/logical.py @@ -3,24 +3,14 @@ class BoundWhere(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) - @Bound.save_io_shape def forward(self, condition, x, y): return torch.where(condition.to(torch.bool), x, y) def interval_propagate(self, *v): assert not self.is_input_perturbed(0) - - if Interval.use_relative_bounds(*v): - return Interval( - None, None, - self.forward(v[0].nominal, v[1].nominal, v[2].nominal), - self.forward(v[0].nominal, v[1].lower_offset, v[2].lower_offset), - self.forward(v[0].nominal, 
v[1].upper_offset, v[2].upper_offset) - ) - condition = v[0][0] return tuple([torch.where(condition, v[1][j], v[2][j]) for j in range(2)]) @@ -42,14 +32,9 @@ def _bound_oneside(last_A): return [(None, None), (lA_x, uA_x), (lA_y, uA_y)], 0, 0 - def infer_batch_dim(self, batch_size, *x): - return BoundMul.infer_batch_dim(batch_size, *x[1:]) - - class BoundNot(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) - @Bound.save_io_shape def forward(self, x): return x.logical_not() \ No newline at end of file diff --git a/auto_LiRPA/operators/nonlinear.py b/auto_LiRPA/operators/nonlinear.py new file mode 100644 index 0000000..4f6bb5c --- /dev/null +++ b/auto_LiRPA/operators/nonlinear.py @@ -0,0 +1,668 @@ +"""Unary nonlinearities other than activation functions.""" +import math +import torch +from .activation import BoundActivation, BoundTanh +from .base import epsilon, LinearBound + +class BoundSin(BoundActivation): + # Lookup tables shared by all BoundSin classes. + xl_lower_tb = None + xl_upper_tb = None + xu_lower_tb = None + xu_upper_tb = None + func, d_func = torch.sin, torch.cos + n_table_entries = 1001 + + @staticmethod + def n_crossing(start, end, s): + """Check how many times we will encounter value s + k*2*pi within start and end for any integer k.""" + dtype = start.dtype + cycles = torch.floor((end - start) / (2 * math.pi)) # Number of 2pi cycles. + # Move s and end to the same 2 * pi cycle as start. + dist = torch.floor((s - start) / (2 * math.pi)) + real_s = s - dist * 2 * math.pi + real_end = end - cycles * 2 * math.pi + # assert (real_end >= start - 2 ** (-20)).all() + return (real_s >= start).to(dtype) * (real_s <= real_end).to(dtype) + cycles + + @staticmethod + def get_intersection(start, end, c, theta=0.): + """Get the number of intersections between y = sin(x + theta) and y = c between start and end.""" + # Use arcsine to find the first 2 intersections. + crossing1 = torch.arcsin(c) - theta + crossing2 = math.pi - crossing1 - 2 * theta # Problematic at exact 1/2 pi, but ok in our case (happens only when lb=ub). + return BoundSin.n_crossing(start, end, crossing1) + BoundSin.n_crossing(start, end, crossing2) + + @staticmethod + def get_bounding_slope(xl, xu, c, theta=0.): + """Find the point between xl and xu such that the tangent line at that point is a lower/upper bound.""" + dtype = xl.dtype + crossing1 = torch.arcsin(c) - theta # output is [-0.5 pi, 0.5 pi] - theta. For cosine, theta=0.5 pi and crossing point is between -pi to 0. + crossing2 = math.pi - crossing1 - 2 * theta # output is [0.5pi, 1.5pi] - theta. For cosine, it is between 0 and pi. + # Find the crossing point between xl and xu. + # First see how xl is away from the [-0.5pi, 1.5pi] range for sine or [-pi, pi] range for cosine. + cycles1 = torch.floor((crossing1 - xl) / (2 * math.pi)) * 2 * math.pi + # Move the two crossing points to the same cycle as xl. + crossing1_moved = crossing1 - cycles1 + cycles2 = torch.floor((crossing2 - xl) / (2 * math.pi)) * 2 * math.pi + crossing2_moved = crossing2 - cycles2 + # Then check which crossing point is the actual tangent point between xl and xu. 
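The crossing-counting helpers above can be cross-checked numerically. A sketch, assuming `BoundSin` from this new file is importable and that sign changes on a fine grid catch every crossing:

```python
import torch
from auto_LiRPA.operators.nonlinear import BoundSin

def brute_force_crossings(start, end, c, n=200001):
    # Count sign changes of sin(t) - c on a fine grid over [start, end].
    t = torch.linspace(float(start), float(end), n)
    s = torch.sin(t) - c
    return int(((s[:-1] * s[1:]) < 0).sum())

start, end, c = torch.tensor(-1.0), torch.tensor(9.0), torch.tensor(0.3)
# sin(x) = 0.3 has three solutions in [-1, 9].
assert int(BoundSin.get_intersection(start, end, c)) == brute_force_crossings(start, end, c) == 3
```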
+        crossing1_used = (crossing1_moved >= xl).to(dtype) * (crossing1_moved <= xu).to(dtype)
+        crossing2_used = (crossing2_moved >= xl).to(dtype) * (crossing2_moved <= xu).to(dtype)
+        crossing_point = crossing1_used * crossing1_moved + crossing2_used * crossing2_moved
+        # print(f'c1={crossing1.item():.05f}, c2={crossing2.item():.05f}, cy1={cycles1.item():.05f}, cy2={cycles2.item():.05f}, c1m={crossing1_moved.item():.05f}, c2m={crossing2_moved.item():.05f}, u1={crossing1_used.item()}, u2={crossing2_used.item()}, xl={xl.item():.05f}, xu={xu.item():.05f}')
+        return crossing_point
+
+    @staticmethod
+    def check_bound(tangent_point, x):
+        """Check whether the tangent line at tangent_point is a valid lower/upper bound for x."""
+        # Evaluate the tangent line at x and check the sign of the margin.
+        d = BoundSin.d_func(tangent_point)
+        val = d * (x - tangent_point) + BoundSin.func(tangent_point)
+        # We want a positive margin when finding a lower line, but as close to 0 as possible.
+        # We want a negative margin when finding an upper line, but as close to 0 as possible.
+        margin = BoundSin.func(x) - val
+        return margin
+
+    @staticmethod
+    @torch.no_grad()
+    def get_lower_left_bound(xl, steps=20):
+        """Get a global lower bound given a lower bound on x. Return slope and intercept."""
+        dtype = xl.dtype
+        # Constrain xl into the -0.5 pi to 1.5 pi region.
+        cycles = torch.floor((xl + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+        xl = xl - cycles
+        use_tangent_line = (xl >= math.pi).to(dtype)
+        # Case 1: xl >= pi. The tangent line at xl is the only possible lower bound.
+        case1_d = BoundSin.d_func(xl)
+        case1_b = BoundSin.func(xl) - case1_d * (xl + cycles)
+        # Case 2: binary search needed, testing tangent points in [pi, 1.5*pi]; the tangent point must be in this region.
+        left = math.pi * torch.ones_like(xl)
+        # The right end guarantees the margin > 0 because it is basically an IBP lower bound (-1).
+        right = (1.5 * math.pi) * torch.ones_like(xl)
+        last_right = right.clone()
+        for i in range(steps):
+            mid = (left + right) / 2.
+            margin = BoundSin.check_bound(mid, xl)
+            pos_mask = (margin > 0).to(dtype)  # We want the margin > 0, but as close to 0 as possible.
+            neg_mask = 1.0 - pos_mask
+            right = mid * pos_mask + right * neg_mask  # We have a positive margin, reduce the right hand side.
+            last_right = mid * pos_mask + last_right * neg_mask  # Always sound, since the margin is positive.
+            left = mid * neg_mask + left * pos_mask
+        case2_d = BoundSin.d_func(last_right)
+        case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles)
+        d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line)
+        b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line)
+        # Return slope and bias.
+        return [d, b]
+
+    @staticmethod
+    @torch.no_grad()
+    def get_upper_left_bound(xl, steps=20):
+        """Get a global upper bound given a lower bound on x. Return slope and intercept."""
+        dtype = xl.dtype
+        # Constrain xl into the 0.5 pi to 2.5 pi region.
+        cycles = torch.floor((xl - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+        xl = xl - cycles
+        use_tangent_line = (xl >= 2.0 * math.pi).to(dtype)
+        # Case 1: xl >= 2*pi. The tangent line at xl is the only possible upper bound.
+        case1_d = BoundSin.d_func(xl)
+        case1_b = BoundSin.func(xl) - case1_d * (xl + cycles)
+        # Case 2: binary search needed, testing tangent points in [2*pi, 2.5*pi]; the tangent point must be in this region.
+        left = (2.0 * math.pi) * torch.ones_like(xl)
+        # The right end guarantees the margin < 0 because it is basically an IBP upper bound (+1).
+        right = (2.5 * math.pi) * torch.ones_like(xl)
+        last_right = right.clone()
+        for i in range(steps):
+            mid = (left + right) / 2.
+            margin = BoundSin.check_bound(mid, xl)
+            pos_mask = (margin > 0).to(dtype)  # We want the margin < 0, but as close to 0 as possible.
+            neg_mask = 1.0 - pos_mask
+            right = mid * neg_mask + right * pos_mask  # We have a negative margin, reduce the right hand side.
+            last_right = mid * neg_mask + last_right * pos_mask  # Always sound, since the margin is negative.
+            left = mid * pos_mask + left * neg_mask
+        case2_d = BoundSin.d_func(last_right)
+        case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles)
+        d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line)
+        b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line)
+        # Return slope and bias.
+        return [d, b]
+
+    @staticmethod
+    @torch.no_grad()
+    def get_lower_right_bound(xu, steps=20):
+        """Get a global lower bound given an upper bound on x. Return slope and intercept."""
+        # Constrain xu into the -0.5 pi to 1.5 pi region.
+        cycles = torch.floor((xu + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+        xu = xu - cycles
+        d, _ = BoundSin.get_lower_left_bound(math.pi - xu, steps)
+        return [-d, BoundSin.func(xu) + d * (xu + cycles)]
+
+    @staticmethod
+    @torch.no_grad()
+    def get_upper_right_bound(xu, steps=20):
+        """Get a global upper bound given an upper bound on x. Return slope and intercept."""
+        # Constrain xu into the 0.5 pi to 2.5 pi region.
+        cycles = torch.floor((xu - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi)
+        xu = xu - cycles
+        d, _ = BoundSin.get_upper_left_bound(3 * math.pi - xu, steps)
+        return [-d, BoundSin.func(xu) + d * (xu + cycles)]
+
+    def __init__(self, attr, inputs, output_index, options):
+        super().__init__(attr, inputs, output_index, options)
+        # Bound limits used by IBP.
+        self.max_point = math.pi / 2
+        self.min_point = math.pi * 3 / 2
+
+        self.all_table_x = torch.linspace(0, 2 * math.pi, BoundSin.n_table_entries, device=self.device)
+        if BoundSin.xl_lower_tb is None:
+            # Generate look-up tables.
+            BoundSin.xl_lower_tb = BoundSin.get_lower_left_bound(self.all_table_x)
+            BoundSin.xl_upper_tb = BoundSin.get_upper_left_bound(self.all_table_x)
+            BoundSin.xu_lower_tb = BoundSin.get_lower_right_bound(self.all_table_x)
+            BoundSin.xu_upper_tb = BoundSin.get_upper_right_bound(self.all_table_x)
+            BoundSin.xl_lower_tb[0], BoundSin.xl_lower_tb[1] = BoundSin.xl_lower_tb[0].to(self.device), BoundSin.xl_lower_tb[1].to(self.device)
+            BoundSin.xl_upper_tb[0], BoundSin.xl_upper_tb[1] = BoundSin.xl_upper_tb[0].to(self.device), BoundSin.xl_upper_tb[1].to(self.device)
+            BoundSin.xu_lower_tb[0], BoundSin.xu_lower_tb[1] = BoundSin.xu_lower_tb[0].to(self.device), BoundSin.xu_lower_tb[1].to(self.device)
+            BoundSin.xu_upper_tb[0], BoundSin.xu_upper_tb[1] = BoundSin.xu_upper_tb[0].to(self.device), BoundSin.xu_upper_tb[1].to(self.device)
+
+    @staticmethod
+    def interpoloate(x, lower_x, upper_x, lower_y, upper_y):
+        # x = torch.clamp(x, min=lower_x, max=upper_x)  # For pytorch >= 1.11
+        x = torch.max(torch.min(x, upper_x), lower_x)
+        ratio = (x - lower_x) / (upper_x - lower_x + 1e-10)
+        return lower_y * (1. - ratio) + upper_y * ratio
+
+    def get_bound_tb(self, tb, x):
+        """Find lower or upper bounds from the lookup table."""
+        step = 2 * math.pi / (BoundSin.n_table_entries - 1)
+        # Move x into the [0, 2*pi] region.
+        cycles = torch.floor(x / (2 * math.pi)) * (2 * math.pi)
+        x = torch.clamp(x - cycles, min=0, max=2 * math.pi)
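A toy version of the table interpolation used below (grid and table names are illustrative): values stored at `n_table_entries` uniform grid points over [0, 2*pi] are recovered by linear interpolation between the two nearest entries:

```python
import math
import torch

n_table_entries = 1001
grid = torch.linspace(0, 2 * math.pi, n_table_entries)
table = torch.cos(grid)                 # toy table: store the slope d = cos at grid points
step = 2 * math.pi / (n_table_entries - 1)

x = torch.tensor([1.2345, 4.5678])
idx = x.div(step).long()
idx_hi = torch.clamp(idx + 1, max=n_table_entries - 1)
ratio = (x - grid[idx]) / (grid[idx_hi] - grid[idx] + 1e-10)
d = table[idx] * (1. - ratio) + table[idx_hi] * ratio
assert torch.allclose(d, torch.cos(x), atol=1e-4)
```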
+        # Find the index within the lookup table over [0, 2*pi].
+        indices = x.div(step).long()
+        # Interpolate between the nearest d and b. This has better differentiability.
+        # Another option is to always take the lower/upper side (better soundness).
+        upper_indices = torch.clamp(indices + 1, max=BoundSin.n_table_entries - 1)
+        lower_x = self.all_table_x[indices]
+        upper_x = self.all_table_x[upper_indices]
+        # print(indices.item(), lower_x.item(), upper_x.item(), tb[0][indices].item(), tb[0][upper_indices].item())
+        d = self.interpoloate(x, lower_x, upper_x, tb[0][indices], tb[0][upper_indices])
+        b = self.interpoloate(x, lower_x, upper_x, tb[1][indices], tb[1][upper_indices])
+        return d, b - d * cycles
+
+    def forward(self, x):
+        return torch.sin(x)
+
+    def interval_propagate(self, *v):
+        # Check if a point is in [l, u], considering the 2*pi period.
+        def check_crossing(ll, uu, point):
+            return ((((uu - point) / (2 * math.pi)).floor() - ((ll - point) / (2 * math.pi)).floor()) > 0).to(h_Ls.dtype)
+        h_L, h_U = v[0][0], v[0][1]
+        h_Ls, h_Us = self.forward(h_L), self.forward(h_U)
+        # If [l, u] crosses pi/2 (mod 2*pi), the max is fixed at 1.0.
+        max_mask = check_crossing(h_L, h_U, self.max_point)
+        # If [l, u] crosses pi*3/2 (mod 2*pi), the min is fixed at -1.0.
+        min_mask = check_crossing(h_L, h_U, self.min_point)
+        ub = torch.max(h_Ls, h_Us)
+        ub = max_mask + (1 - max_mask) * ub
+        lb = torch.min(h_Ls, h_Us)
+        lb = - min_mask + (1 - min_mask) * lb
+        return lb, ub
+
+    def bound_relax_impl(self, lb, ub):
+        dtype = lb.dtype
+        # Case 1: connect the two endpoints with a line.
+        sub = self.func(ub)
+        slb = self.func(lb)
+        mid = (sub + slb) / 2.
+        smid = self.func((ub + lb) / 2)
+        case1_line_slope = (sub - slb) / (ub - lb + 1e-10)
+        case1_line_bias = slb - case1_line_slope * lb
+        gap = smid - mid
+        # Check how many times the derivative of sin attains the line's slope within [lb, ub].
+        grad_crossings = self.get_intersection(lb, ub, case1_line_slope, theta=0.5 * math.pi)
+        # If the slope is attained exactly once, the connecting line is a valid lower/upper bound.
+        use_line = grad_crossings == 1
+        # Connected line is the upper bound.
+        upper_use_line = torch.logical_and(gap < 0, use_line)
+        # Connected line is the lower bound.
+        lower_use_line = torch.logical_and(gap >= 0, use_line)
+        # For the other bound, use the tangent line.
+        case1_tangent_point = self.get_bounding_slope(lb, ub, case1_line_slope, theta=0.5 * math.pi)
+        case1_tangent_slope = case1_line_slope  # Use the same slope so far.
+        stangent = self.func(case1_tangent_point)
+        case1_tangent_bias = stangent - case1_tangent_slope * case1_tangent_point
+        # Choose the lower/upper bound based on the gap.
+        case1_lower_slope = lower_use_line * case1_line_slope + upper_use_line * case1_tangent_slope
+        case1_lower_bias = lower_use_line * case1_line_bias + upper_use_line * case1_tangent_bias
+        case1_upper_slope = upper_use_line * case1_line_slope + lower_use_line * case1_tangent_slope
+        case1_upper_bias = upper_use_line * case1_line_bias + lower_use_line * case1_tangent_bias
+
+        # Case 2: try the global lower/upper bounds at lb and ub.
+        # For the points lb and ub, we can construct both lower and upper bounds.
+        left_lower = self.get_bound_tb(BoundSin.xl_lower_tb, lb)  # slope, bias.
+        left_upper = self.get_bound_tb(BoundSin.xl_upper_tb, lb)
+        right_lower = self.get_bound_tb(BoundSin.xu_lower_tb, ub)
+        right_upper = self.get_bound_tb(BoundSin.xu_upper_tb, ub)
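The interval rule implemented in `interval_propagate` above can be checked against sampling; a self-contained sketch of the same logic:

```python
import math
import torch

def sin_interval(l, u):
    # A point p is crossed (mod 2*pi) iff floor((u-p)/2pi) > floor((l-p)/2pi).
    def crosses(p):
        return torch.floor((u - p) / (2 * math.pi)) > torch.floor((l - p) / (2 * math.pi))
    sl, su = torch.sin(l), torch.sin(u)
    ub = torch.where(crosses(torch.tensor(0.5 * math.pi)), torch.ones_like(sl), torch.maximum(sl, su))
    lb = torch.where(crosses(torch.tensor(1.5 * math.pi)), -torch.ones_like(sl), torch.minimum(sl, su))
    return lb, ub

l = torch.randn(1000) * 5
u = l + torch.rand(1000) * 5
lb, ub = sin_interval(l, u)
x = l + (u - l) * torch.rand(64, 1000)   # random points inside each interval
s = torch.sin(x)
assert (s >= lb - 1e-6).all() and (s <= ub + 1e-6).all()
```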
+        # Determine which lower bound is tighter.
+        left_lower_error = sub - (left_lower[0] * ub + left_lower[1])
+        right_lower_error = slb - (right_lower[0] * lb + right_lower[1])
+        left_upper_error = (left_upper[0] * ub + left_upper[1]) - sub
+        right_upper_error = (right_upper[0] * lb + right_upper[1]) - slb
+        use_left_lower = (left_lower_error < right_lower_error).to(dtype)
+        use_right_lower = 1. - use_left_lower
+        use_left_upper = (left_upper_error < right_upper_error).to(dtype)
+        use_right_upper = 1. - use_left_upper
+        # Choose the slope and bias in this case.
+        case_2_lower_slope = use_left_lower * left_lower[0] + use_right_lower * right_lower[0]
+        case_2_lower_bias = use_left_lower * left_lower[1] + use_right_lower * right_lower[1]
+        case_2_upper_slope = use_left_upper * left_upper[0] + use_right_upper * right_upper[0]
+        case_2_upper_bias = use_left_upper * left_upper[1] + use_right_upper * right_upper[1]
+
+        # Finally, choose between case 1 and case 2.
+        use_line = use_line.to(dtype)
+        not_use_line = 1. - use_line
+        lower_slope = use_line * case1_lower_slope + not_use_line * case_2_lower_slope
+        lower_bias = use_line * case1_lower_bias + not_use_line * case_2_lower_bias
+        upper_slope = use_line * case1_upper_slope + not_use_line * case_2_upper_slope
+        upper_bias = use_line * case1_upper_bias + not_use_line * case_2_upper_bias
+        # print(gap, lower_slope, lower_bias, upper_slope, upper_bias)
+        return lower_slope, lower_bias, upper_slope, upper_bias
+
+    def bound_relax(self, x):
+        lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(x.lower, x.upper)
+        self.lw = lower_slope
+        self.lb = lower_bias
+        self.uw = upper_slope
+        self.ub = upper_bias
+
+
+class BoundCos(BoundSin):
+    def __init__(self, attr, inputs, output_index, options):
+        super().__init__(attr, inputs, output_index, options)
+        self.max_point = 0.0
+        self.min_point = math.pi
+
+    def forward(self, x):
+        return torch.cos(x)
+
+    def bound_relax(self, x):
+        # Shift the input by 0.5*pi, then shift the linear bounds back.
+        lb = x.lower + 0.5 * math.pi
+        ub = x.upper + 0.5 * math.pi
+        lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(lb, ub)
+        self.lw = lower_slope
+        self.lb = lower_slope * (0.5 * math.pi) + lower_bias
+        self.uw = upper_slope
+        self.ub = upper_slope * (0.5 * math.pi) + upper_bias
+
+
+class BoundAtan(BoundTanh):
+    def __init__(self, attr, inputs, output_index, options):
+        super(BoundTanh, self).__init__(attr, inputs, output_index, options)
+        self.precompute_relaxation('arctan', torch.arctan, self.darctan)
+        # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions.
+        self.alpha_batch_dim = 3
+
+    def forward(self, x):
+        return torch.arctan(x)
+
+    def darctan(self, x):
+        return (x.square() + 1.).reciprocal()
+
+    def bound_relax(self, x):
+        self.bound_relax_impl(x, torch.arctan, self.darctan)
+
+
+class BoundTan(BoundAtan):
+    """
+    The implementation of BoundTan is based on the S-shaped BoundAtan. We use the bounds of its
+    inverse function (arctan) and directly convert them to bounds of the original function.
+    This trick allows us to quickly implement bounds on inverse functions.
+    """
+    def __init__(self, attr, inputs, output_index, options):
+        super().__init__(attr, inputs, output_index, options)
+
+    def forward(self, x):
+        return torch.tan(x)
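Before the periodicity handling below, the inversion trick itself can be illustrated on a single period. A sketch; the chord/tangent choices here are illustrative, not the ones BoundTanh computes:

```python
import torch

l, u = torch.tensor(0.3), torch.tensor(1.0)        # both in (0, 0.5*pi): period 0, arctan concave on x > 0
xl, xu = torch.tan(l), torch.tan(u)
lw = (torch.atan(xu) - torch.atan(xl)) / (xu - xl) # chord: a lower bound of arctan here
lb = torch.atan(xl) - lw * xl
xm = (xl + xu) / 2
uw = 1. / (1. + xm * xm)                           # tangent at midpoint: an upper bound of arctan here
ub = torch.atan(xm) - uw * xm
# Inverting lw*tan(z)+lb <= z <= uw*tan(z)+ub gives (z-ub)/uw <= tan(z) <= (z-lb)/lw.
z = torch.linspace(0.3, 1.0, 1001)
assert ((z - ub) / uw <= torch.tan(z) + 1e-5).all()
assert ((z - lb) / lw >= torch.tan(z) - 1e-5).all()
```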
+    def _check_bounds(self, lower, upper):
+        # Lower and upper bounds must be within the same [-½π, ½π] region.
+        lower_periods = torch.floor((lower + 0.5 * torch.pi) / torch.pi)
+        upper_periods = torch.floor((upper + 0.5 * torch.pi) / torch.pi)
+        if not torch.allclose(lower_periods, upper_periods):
+            print('Tan preactivation lower bounds:\n', lower)
+            print('Tan preactivation upper bounds:\n', upper)
+            raise ValueError("BoundTan received pre-activation bounds that produce infinity. "
+                             "The preactivation bounds are too loose. Try to reduce the perturbation region.")
+        # Return the period number for each neuron.
+        # Period is 0 => bounds are within [-½π, ½π],
+        # Period is 1 => bounds are within [-½π + π, ½π + π],
+        # Period is -1 => bounds are within [-½π - π, ½π - π].
+        return lower_periods
+
+    def _init_masks(self, x):
+        # The masks now must consider the periodicity.
+        lower = torch.remainder(x.lower + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi
+        upper = torch.remainder(x.upper + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi
+        self.mask_pos = lower >= 0
+        self.mask_neg = upper <= 0
+        self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg))
+
+    def interval_propagate(self, *v):
+        # We need to check if the input lower and upper bounds are within the same period.
+        # Otherwise the bounds become infinity.
+        concrete_lower, concrete_upper = v[0][0], v[0][1]
+        self._check_bounds(concrete_lower, concrete_upper)
+        return super().interval_propagate(*v)
+
+    def bound_relax(self, x):
+        periods = self._check_bounds(x.lower, x.upper)
+        periods = torch.pi * periods
+        # Create a fake x with inverted lower and upper, i.e., the pre-activation bounds of arctan.
+        inverse_x = lambda: None
+        inverse_x.lower = torch.tan(x.lower)
+        inverse_x.upper = torch.tan(x.upper)
+        super().bound_relax(inverse_x)
+        # Lower slope, lower bias, upper slope and upper bias are saved to
+        # self.lw, self.lb, self.uw, self.ub. We need to invert them.
+        # E.g., y = self.lw * x + self.lb now becomes x = 1./self.lw * y - self.lb / self.lw.
+        # Additionally, we need to add back the missing multiples of π.
+        new_upper_slope = 1. / self.lw
+        new_upper_bias = - self.lb / self.lw - periods / self.lw
+        new_lower_slope = 1.
/ self.uw + new_lower_bias = - self.ub / self.uw - periods / self.uw + self.lw = new_lower_slope + self.lb = new_lower_bias + self.uw = new_upper_slope + self.ub = new_upper_bias + + +class BoundExp(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.options = options.get('exp') + self.max_input = 0 + + def forward(self, x): + if self.loss_fusion and self.options != 'no-max-input': + self.max_input = torch.max(x, dim=-1, keepdim=True)[0].detach() + return torch.exp(x - self.max_input) + return torch.exp(x) + + def interval_propagate(self, *v): + assert len(v) == 1 + # unary monotonous functions only + h_L, h_U = v[0] + if self.loss_fusion and self.options != 'no-max-input': + self.max_input = torch.max(h_U, dim=-1, keepdim=True)[0] + h_L, h_U = h_L - self.max_input, h_U - self.max_input + else: + self.max_input = 0 + return torch.exp(h_L), torch.exp(h_U) + + def bound_forward(self, dim_in, x): + m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) + + exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper) + + kl = exp_m + lw = x.lw * kl.unsqueeze(1) + lb = kl * (x.lb - m + 1) + + ku = (exp_u - exp_l) / (x.upper - x.lower + epsilon) + uw = x.uw * ku.unsqueeze(1) + ub = x.ub * ku - ku * x.lower + exp_l + + return LinearBound(lw, lb, uw, ub) + + def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): + # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable). + if self.loss_fusion and last_lA is None and last_uA is not None and torch.min( + last_uA) >= 0 and x.from_input: + # Adding an extra bias term to the input. This is equivalent to adding a constant and subtract layer before exp. + # Note that we also need to adjust the bias term at the end. + if self.options == 'no-detach': + self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0] + elif self.options != 'no-max-input': + self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0].detach() + else: + self.max_input = 0 + adjusted_lower = x.lower - self.max_input + adjusted_upper = x.upper - self.max_input + # relaxation for upper bound only (used in loss fusion) + exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper) + k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower + epsilon) + if k.requires_grad: + k = k.clamp(min=1e-6) + uA = last_uA * k.unsqueeze(0) + ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0) + + if ubias.ndim > 2: + ubias = torch.sum(ubias, dim=tuple(range(2, ubias.ndim))) + # Also adjust the missing ubias term. 
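As an aside, the basic exp relaxation used by `bound_forward` above (tangent line at `m` as the lower bound, chord as the upper bound) is easy to check numerically. A self-contained sketch:

```python
import torch

l, u = torch.tensor(-1.0), torch.tensor(2.0)
m = torch.min((l + u) / 2, l + 0.99)   # same tangent point choice as above
x = torch.linspace(float(l), float(u), 1001)
lower = torch.exp(m) * (x - m + 1)     # tangent line at m, valid globally since exp is convex
k = (torch.exp(u) - torch.exp(l)) / (u - l)
upper = torch.exp(l) + k * (x - l)     # chord between the endpoints, valid on [l, u]
assert (torch.exp(x) >= lower - 1e-5).all() and (torch.exp(x) <= upper + 1e-5).all()
```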
+ if uA.ndim > self.max_input.ndim: + A = torch.sum(uA, dim=tuple(range(self.max_input.ndim, uA.ndim))) + else: + A = uA + + # These should hold true in loss fusion + assert self.batch_dim == 0 + assert A.shape[0] == 1 + + batch_size = A.shape[1] + ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0) + return [(None, uA)], 0, ubias + else: + return super().bound_backward(last_lA, last_uA, x) + + def bound_relax(self, x): + min_val = -1e9 + l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val) + m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) + exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper) + k = exp_m + self.add_linear_relaxation(mask=None, type='lower', k=k, x0=m, y0=exp_m) + min_val = -1e9 # to avoid (-inf)-(-inf) when both input.lower and input.upper are -inf + epsilon = 1e-20 + close = (u - l < epsilon).int() + k = close * exp_u + (1 - close) * (exp_u - exp_l) / (u - l + epsilon) + self.add_linear_relaxation(mask=None, type='upper', k=k, x0=l, y0=exp_l) + + +class BoundLog(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + + def forward(self, x): + # NOTE adhoc implementation for loss fusion + if self.loss_fusion: + return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1) + return torch.log(x.clamp(min=epsilon)) + + def bound_relax(self, x): + rl, ru = self.forward(x.lower), self.forward(x.upper) + ku = (ru - rl) / (x.upper - x.lower + epsilon) + self.add_linear_relaxation(mask=None, type='lower', k=ku, x0=x.lower, y0=rl) + m = (x.lower + x.upper) / 2 + k = torch.reciprocal(m) + rm = self.forward(m) + self.add_linear_relaxation(mask=None, type='upper', k=k, x0=m, y0=rm) + + def interval_propagate(self, *v): + # NOTE adhoc implementation for loss fusion + if self.loss_fusion: + par = self.inputs[0].inputs[0].inputs[0] + lower = torch.logsumexp(par.lower, dim=-1) + upper = torch.logsumexp(par.upper, dim=-1) + return lower, upper + return super().interval_propagate(*v) + + def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): + A, lbias, ubias = super().bound_backward(last_lA, last_uA, x) + # NOTE adhoc implementation for loss fusion + if self.loss_fusion: + assert A[0][0] is None + exp_module = self.inputs[0].inputs[0] + ubias = ubias + self.get_bias(A[0][1], exp_module.max_input.squeeze(-1)) + return A, lbias, ubias + + +class BoundPow(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + + def forward(self, x, y): + return torch.pow(x, y) + + def bound_backward(self, last_lA, last_uA, x, y): + assert not self.is_input_perturbed(1) + y = y.lower.item() + if y == int(y) and y == 2: + x_l = x.lower + x_u = torch.max(x.upper, x.lower + 1e-8) + + pow_l = self.forward(x_l, y) + pow_u = self.forward(x_u, y) + k_u = (pow_u - pow_l) / (x_u - x_l).clamp(min=1e-8) + b_u = pow_l - k_u * x_l + + k_l = torch.zeros_like(k_u) + b_l = torch.zeros_like(b_u) + x_m = (x_l + x_u) / 2 + + # TODO this only holds for y=2 + x_m = (x_u < 0) * torch.max(x_m, x_u * 2) + (x_l > 0) * torch.min(x_m, x_l * 2) + k_l = y * self.forward(x_m, y - 1) + b_l = self.forward(x_m, y) - k_l * x_m + + if last_lA is not None: + last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0) + lA = last_lA_pos * k_l + last_lA_neg * k_u + lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u) + else: + lA, 
lb = None, 0 + + if last_uA is not None: + last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0) + uA = last_uA_pos * k_u + last_uA_neg * k_l + ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l) + else: + uA, ub = None, 0 + + return [(lA, uA), (None, None)], lb, ub + else: + raise NotImplementedError(f'Exponent {y} is not supported yet') + + def interval_propagate(self, *v): + assert not self.is_input_perturbed(1) + exp = v[1][0] + assert exp == int(exp) + exp = int(exp) + pl, pu = torch.pow(v[0][0], exp), torch.pow(v[0][1], exp) + if exp % 2 == 1: + return pl, pu + else: + pl, pu = torch.min(pl, pu), torch.max(pl, pu) + mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).to(pl.dtype) + return pl * mask, pu + + +class BoundReciprocal(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + + def forward(self, x): + return torch.reciprocal(x) + + def bound_relax(self, x): + m = (x.lower + x.upper) / 2 + kl = -1 / m.pow(2) + self.add_linear_relaxation(mask=None, type='lower', k=kl, x0=m, y0=1. / m) + ku = -1. / (x.lower * x.upper) + self.add_linear_relaxation(mask=None, type='upper', k=ku, x0=x.lower, y0=1. / x.lower) + + def interval_propagate(self, *v): + h_L, h_U = v[0][0].float(), v[0][1].float() + assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal' + return torch.reciprocal(h_U), torch.reciprocal(h_L) + + +class BoundSqrt(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + + def forward(self, x): + return torch.sqrt(x) + + def interval_propagate(self, *v): + return super().interval_propagate(*v) + + def bound_backward(self, last_lA, last_uA, x): + x_l = x.lower + x_u = torch.max(x.upper, x.lower + 1e-8) + sqrt_l = self.forward(x_l) + sqrt_u = self.forward(x_u) + k_l = (sqrt_u - sqrt_l) / (x_u - x_l).clamp(min=1e-8) + b_l = sqrt_l - k_l * x_l + + x_m = (x_l + x_u) / 2 + sqrt_m = self.forward(x_m) + k_u = -0.5 * torch.pow(x_m, -1.5) + b_u = sqrt_m - k_u * x_m + + # TODO make this part a general function + if last_lA is not None: + last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0) + lA = last_lA_pos * k_l + last_lA_neg * k_u + lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u) + else: + lA, lb = None, 0 + if last_uA is not None: + last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0) + uA = last_uA_pos * k_u + last_uA_neg * k_l + ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l) + else: + uA, ub = None, 0 + + return [(lA, uA), (None, None)], lb, ub + + +class BoundSqr(BoundActivation): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.nonlinear = True + + def forward(self, x): + return x**2 + + def bound_backward(self, last_lA, last_uA, x): + x_L, x_U = x.lower, x.upper + upper_k = x_U + x_L + upper_b = x_L**2 - upper_k * x_L + if last_uA is not None: + # Special case if we only want the upper bound with non-negative + # coefficients. 
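The coefficients used for `x**2` here come from two standard facts: on [l, u] the chord `(l+u)*x - l*u` upper-bounds `x**2`, and any tangent line lower-bounds it. A quick check:

```python
import torch

l, u = torch.tensor(-1.5), torch.tensor(2.0)
x = torch.linspace(float(l), float(u), 1001)
upper = (l + u) * x - l * u    # chord: upper_k = x_U + x_L, upper_b = x_L**2 - upper_k * x_L = -x_L * x_U
m = (l + u) / 2
lower = 2 * m * x - m * m      # any tangent line is a global lower bound (convexity)
assert (x * x <= upper + 1e-6).all() and (x * x >= lower - 1e-6).all()
```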
+ if last_uA.min() >= 0: + uA = last_uA * upper_k + ubias = self.get_bias(last_uA, upper_b) + else: + raise NotImplementedError + else: + uA, ubias = None, 0 + if last_lA is not None: + if last_lA.max() <= 0: + lA = last_lA * upper_k + lbias = self.get_bias(last_lA, upper_b) + else: + raise NotImplementedError + else: + lA, lbias = None, 0 + return [(lA, uA)], lbias, ubias + + def interval_propagate(self, *v): + h_L, h_U = v[0][0], v[0][1] + lower = ((h_U < 0) * (h_U**2) + (h_L > 0) * (h_L**2)) + upper = torch.max(h_L**2, h_U**2) + return lower, upper diff --git a/auto_LiRPA/operators/normalization.py b/auto_LiRPA/operators/normalization.py index 75b669b..d1050ed 100644 --- a/auto_LiRPA/operators/normalization.py +++ b/auto_LiRPA/operators/normalization.py @@ -1,23 +1,25 @@ """ Normalization operators""" from .base import * +from .solver_utils import grb class BoundBatchNormalization(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device, training): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options, training): + super().__init__(attr, inputs, output_index, options) self.eps = attr['epsilon'] self.momentum = round(1 - attr['momentum'], 5) # take care! - self.mode = options.get("conv_mode", "matrix") self.options = options.get("bn", {}) # modes: # - forward: use mean and variance estimated from clean forward pass # - ibp: use mean and variance estimated from ibp - self.bn_mode = self.options.get("mode", "forward") + self.bn_mode = self.options.get("mode", "forward") self.use_mean = self.options.get("mean", True) self.use_var = self.options.get("var", True) self.use_affine = self.options.get("affine", True) self.training = training + self.patches_start = True + self.mode = options.get("conv_mode", "matrix") if not self.use_mean or not self.use_var: - logger.info(f'Batch normalization node {self.name}: use_mean {self.use_mean}, use_var {self.use_var}') + logger.info(f'Batch normalization node {self.name}: use_mean {self.use_mean}, use_var {self.use_var}') def _check_unused_mean_or_var(self): # Check if either mean or var is opted out @@ -26,8 +28,9 @@ def _check_unused_mean_or_var(self): if not self.use_var: self.current_var = torch.ones_like(self.current_var) - @Bound.save_io_shape def forward(self, x, w, b, m, v): + if len(x.shape) == 2: + self.patches_start = False if self.training: dim = [0] + list(range(2, x.ndim)) self.current_mean = x.mean(dim) @@ -35,10 +38,10 @@ def forward(self, x, w, b, m, v): else: self.current_mean = m.data self.current_var = v.data - self._check_unused_mean_or_var() + self._check_unused_mean_or_var() if not self.use_affine: w = torch.ones_like(w) - b = torch.zeros_like(b) + b = torch.zeros_like(b) result = F.batch_norm(x, m, v, w, b, self.training, self.momentum, self.eps) if not self.use_mean or not self.use_var: # If mean or variance is disabled, recompute the output from self.current_mean @@ -61,15 +64,15 @@ def bound_backward(self, last_lA, last_uA, *x): self._check_unused_mean_or_var() if not self.use_affine: weight = torch.ones_like(weight) - bias = torch.zeros_like(bias) - + bias = torch.zeros_like(bias) + tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight tmp_weight = weight / torch.sqrt(self.current_var + self.eps) def _bound_oneside(last_A): if last_A is None: return None, 0 - if type(last_A) == torch.Tensor: + if type(last_A) == Tensor: next_A = last_A * 
tmp_weight.view(*((1, 1, -1) + (1,) * (last_A.ndim - 3))) if last_A.ndim > 3: sum_bias = (last_A.sum(tuple(range(3, last_A.ndim))) * tmp_bias).sum(2) @@ -85,11 +88,12 @@ def _bound_oneside(last_A): # tmp_weight has shape (c,), it will be applied on the (c,) dimension. patches = patches * tmp_weight.view(*([1] * (patches.ndim - 3)), -1, 1, 1) # Match with sparse or non-sparse patches. next_A = Patches(patches, last_A.stride, last_A.padding, last_A.shape, identity=0, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) - + # bias to size (c,), need expansion before unfold. bias = tmp_bias.view(-1,1,1).expand(self.input_shape[1:]).unsqueeze(0) # Unfolded bias has shape (1, out_h, out_w, in_c, H, W). - bias_unfolded = inplace_unfold(bias, kernel_size=last_A.patches.shape[-2:], padding=last_A.padding, stride=last_A.stride) + bias_unfolded = inplace_unfold(bias, kernel_size=last_A.patches.shape[-2:], padding=last_A.padding, stride=last_A.stride, + inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) if last_A.unstable_idx is not None: # Sparse bias has shape (unstable_size, batch, in_c, H, W). bias_unfolded = bias_unfolded[:, last_A.unstable_idx[1], last_A.unstable_idx[2]] @@ -130,6 +134,7 @@ def _bound_oneside(last_A): return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], lbias, ubias + def interval_propagate(self, *v): assert not self.is_input_perturbed(1) and not self.is_input_perturbed(2), \ 'Weight perturbation is not supported for BoundBatchNormalization' @@ -153,17 +158,75 @@ def interval_propagate(self, *v): self._check_unused_mean_or_var() if not self.use_affine: weight = torch.ones_like(weight) - bias = torch.zeros_like(bias) + bias = torch.zeros_like(bias) tmp_weight = weight / torch.sqrt(self.current_var + self.eps) tmp_weight_abs = tmp_weight.abs() tmp_bias = bias - self.current_mean * tmp_weight - shape = (1, -1) + (1,) * (mid.ndim - 2) - center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape) - deviation = tmp_weight_abs.view(*shape) * diff - lower = center - deviation - upper = center + deviation + + # interval_propagate() of the Linear layer may encounter input with different norms. + norm, eps = Interval.get_perturbation(v[0])[:2] + if norm == np.inf: + center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape) + deviation = tmp_weight_abs.view(*shape) * diff + elif norm > 0: + mid = v[0][0] + center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape) + if norm == 2: + ptb = copy.deepcopy(v[0].ptb) + ptb.eps = eps * tmp_weight_abs.max() + return Interval(center, center, ptb=ptb) + else: + # General Lp norm. 
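In eval mode, batch normalization is the affine map `y = w_hat * x + b_hat` with `w_hat = w / sqrt(var + eps)` and `b_hat = b - mean * w_hat`, so the Linf case above is simply `center +/- |w_hat| * radius`. A self-contained sketch (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c = 3
x_mid = torch.randn(1, c, 4, 4)
x_rad = torch.rand(1, c, 4, 4)                 # Linf radius around the center
w, b = torch.randn(c), torch.randn(c)
mean, var, eps = torch.randn(c), torch.rand(c) + 0.1, 1e-5

w_hat = w / torch.sqrt(var + eps)              # tmp_weight above
b_hat = b - mean * w_hat                       # tmp_bias above
shape = (1, -1, 1, 1)
center = w_hat.view(shape) * x_mid + b_hat.view(shape)
deviation = w_hat.abs().view(shape) * x_rad
lower, upper = center - deviation, center + deviation

x = x_mid + x_rad * (2 * torch.rand(100, c, 4, 4) - 1)
y = F.batch_norm(x, mean, var, w, b, training=False, eps=eps)
assert (y >= lower - 1e-4).all() and (y <= upper + 1e-4).all()
```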
+ center = tmp_weight.view(*shape) * mid + deviation = tmp_weight_abs.view(*shape) * eps # use a Linf ball to replace Lp norm + else: + raise NotImplementedError + + lower, upper = center - deviation, center + deviation return lower, upper + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., last layer input gurobi vars (3,32,32) + gvars_array = np.array(v[0]) + # pre_layer_shape (1,3,32,32) + pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape + # this layer shape (1,8,16,16) + this_layer_shape = self.output_shape + + weight, bias = v[1], v[2] + + self.current_mean = v[3] + self.current_var = v[4] + self._check_unused_mean_or_var() + if not self.use_affine: + weight = torch.ones_like(weight) + bias = torch.zeros_like(bias) + + tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight + tmp_weight = weight / torch.sqrt(self.current_var + self.eps) + + new_layer_gurobi_vars = [] + neuron_idx = 0 + for out_chan_idx in range(this_layer_shape[1]): + out_chan_vars = [] + for out_row_idx in range(this_layer_shape[2]): + out_row_vars = [] + for out_col_idx in range(this_layer_shape[3]): + # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1)) + lin_expr = tmp_bias[out_chan_idx].item() + tmp_weight[out_chan_idx].item() * gvars_array[out_chan_idx, out_row_idx, out_col_idx] + var = model.addVar(lb=-float('inf'), ub=float('inf'), + obj=0, vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq') + neuron_idx += 1 + + out_row_vars.append(var) + out_chan_vars.append(out_row_vars) + new_layer_gurobi_vars.append(out_chan_vars) + + self.solver_vars = new_layer_gurobi_vars + # self.solver_constrs = new_layer_gurobi_constrs + model.update() diff --git a/auto_LiRPA/operators/pooling.py b/auto_LiRPA/operators/pooling.py new file mode 100644 index 0000000..6163caa --- /dev/null +++ b/auto_LiRPA/operators/pooling.py @@ -0,0 +1,550 @@ +""" Convolution, pooling and padding operators""" +from multiprocessing.sharedctypes import Value +from .base import * +from .activation import BoundOptimizableActivation +import numpy as np +from .solver_utils import grb + + +class BoundMaxPool(BoundOptimizableActivation): + #FIXME clean up needed + + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2]) + assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3]) + + self.requires_input_bounds = [0] + self.kernel_size = attr['kernel_shape'] + self.stride = attr['strides'] + self.padding = [attr['pads'][0], attr['pads'][1]] + self.ceil_mode = False + self.use_default_ibp = True + self.alpha = None + self.init = {} + self.alpha_batch_dim = 2 + + def forward(self, x): + output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + return output + + def project_simplex(self, patches): + sorted = torch.flatten(patches, -2) + sorted, _ = torch.sort(sorted, -1, descending=True) + rho_sum = torch.cumsum(sorted, -1) + rho_value = 1 - rho_sum + rho_value = (sorted + rho_value/torch.tensor(range(1, sorted.size(-1)+1), dtype=torch.float, device=sorted.device)) > 0 + _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1) + rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1) + lbd = 1/(rho_index+1)* (1-rho_sum) + + return torch.clamp(patches + 
+            lbd.unsqueeze(-1).unsqueeze(-1), min=0)
+
+    def init_opt_parameters(self, start_nodes):
+        self.alpha = OrderedDict()
+        ref = self.inputs[0].lower  # A reference variable for getting the shape.
+        for ns, size_s, unstable_idx in start_nodes:
+            if ns == '_forward':
+                warnings.warn("MaxPool's optimization is not supported for forward mode")
+                continue
+            self.alpha[ns] = torch.empty(
+                [1, size_s, self.input_shape[0], self.input_shape[1],
+                 self.output_shape[-2], self.output_shape[-1],
+                 self.kernel_size[0], self.kernel_size[1]],
+                dtype=torch.float, device=ref.device, requires_grad=True)
+            self.init[ns] = False
+
+    @staticmethod
+    @torch.jit.script
+    def jit_mutiply(Apos, Aneg, pos, neg):
+        return pos.contiguous() * Apos + neg.contiguous() * Aneg
+
+    def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, unstable_idx=None):
+        # self.padding is a tuple of two elements: (height dimension padding, width dimension padding).
+        paddings = tuple((self.padding[0], self.padding[0], self.padding[1], self.padding[1]))
+
+        if self.stride[0] != self.kernel_size[0]:
+            raise ValueError("self.stride ({}) != self.kernel_size ({})".format(self.stride, self.kernel_size))
+        if self.padding[0] != 0:
+            raise ValueError("BoundMaxPool doesn't support padding != 0")
+
+        shape = self.input_shape
+        batch_size = x.lower.shape[0]
+        shape = list(shape[:-2]) + [a + 2*b for a, b in zip(self.input_shape[-2:], self.padding)]
+        shape[0] = batch_size
+        # Lower and upper D matrices. They have size (batch_size, input_c, x, y) and will be
+        # multiplied with the A matrices after enlarging them via F.interpolate.
+        upper_d = torch.zeros(shape, device=x.device)
+        lower_d = None
+
+        # Size of upper_b and lower_b: (batch_size, output_c, h, w).
+        upper_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)
+        lower_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)
+
+        # Find the maxpool neuron whose input bounds satisfy l_i > max_j u_j for all j != i. In this case, the maxpool neuron is linear, and we can set upper_d = lower_d = 1.
+        # We first find which indices have the largest lower bound.
+        max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+        # Set the upper bound of the i-th input to -inf so it will not be selected as the max.
+        if paddings == (0,0,0,0):
+            delete_upper = torch.scatter(
+                torch.flatten(x.upper, -2), -1,
+                torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
+        else:
+            delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape)
+        # Find the max upper bound over the remaining ones.
+        max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode)
+
+        # The upper bound slope for a maxpool neuron is either 1 when some input satisfies l_i > max_j u_j (the neuron is linear), or 0 everywhere. The upper bound is not optimized.
+        values = torch.zeros_like(max_lower)
+        values[max_lower >= max_upper] = 1.0
+        upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape)
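`project_simplex` above is the well-known sort-based Euclidean projection onto the probability simplex (Duchi et al., 2008), applied per pooling window; a 1-D sketch of the same algorithm:

```python
import torch

def project_simplex_1d(v):
    # Sort-based projection onto {x : x >= 0, sum(x) = 1}.
    u, _ = torch.sort(v, descending=True)
    css = torch.cumsum(u, dim=0)
    ks = torch.arange(1, v.numel() + 1, dtype=v.dtype)
    rho = int((u + (1. - css) / ks > 0).nonzero().max()) + 1
    tau = (1. - css[rho - 1]) / rho
    return torch.clamp(v + tau, min=0)

p = project_simplex_1d(torch.randn(6))
assert (p >= 0).all() and abs(float(p.sum()) - 1.) < 1e-6
```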
+        if self.opt_stage == 'opt':
+            if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1:
+                if unstable_idx.ndim == 1:
+                    # Only the unstable neurons of the start node are used.
+                    alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
+                elif unstable_idx.ndim == 2:
+                    # Each element in the batch selects different neurons.
+                    alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)
+                else:
+                    raise ValueError
+            else:
+                alpha = self.alpha[start_node.name]
+
+            if not self.init[start_node.name]:
+                lower_d = torch.zeros((shape), device=x.device)
+                # [batch, C, H, W]
+                lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+                # shape [batch, C*k*k, L]
+                lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride)
+
+                # [batch, C, k, k, out_H, out_W]
+                alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1])
+
+                # [batch, C, out_H, out_W, k, k]
+                alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach())
+                self.init[start_node.name] = True
+                # In optimization mode, we reuse the same lower_d once it has been built.
+                if self.padding[0] > 0 or self.padding[1] > 0:
+                    lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]]
+            # The lower bound coefficients must be positive and projected onto the unit simplex.
+            alpha.data = self.project_simplex(alpha.data).clone().detach()  # TODO: don't do this, never re-assign the .data property. Use copy_ instead.
+            # Permute the last 6 dimensions of alpha to [batch, C, k, k, out_H, out_W], which prepares for the fold operation.
+            alpha = alpha.permute((0,1,2,3,6,7,4,5))
+            alpha_shape = alpha.shape
+            alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1]))
+            lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride)
+            lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:])
+            lower_d = lower_d.squeeze(0)
+        else:
+            lower_d = torch.zeros((shape), device=x.device)
+            # Bounds are not optimized here. We simply set \hat{z} >= z_i where i is the input element with the largest lower bound.
+            lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)
+            if self.padding[0] > 0 or self.padding[1] > 0:
+                lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]]
+
+        # For the upper bound, we set the bias term to concrete upper bounds for maxpool neurons that are not linear.
+        max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)
+        upper_b[max_upper > max_lower] = max_upper_[max_upper > max_lower]
+
+        def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):
+            if last_A is None:
+                return None, 0
+
+            bias = 0
+
+            if isinstance(last_A, torch.Tensor):
+                pos_A = last_A.clamp(min=0)
+                neg_A = last_A.clamp(max=0)
+
+                if b_pos is not None:
+                    # This is matrix mode, and padding is considered in the previous layers.
+                    bias = bias + self.get_bias(pos_A, b_pos)
+                if b_neg is not None:
+                    bias = bias + self.get_bias(neg_A, b_neg)
+
+                # Here we should confirm that the maxpool patches do not overlap.
+ shape = last_A.size() + pos_A = F.interpolate(pos_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) + if self.input_shape[-2] != pos_A.shape[-2] and self.input_shape[-1] != pos_A.shape[-1]: + pos_A = F.pad(pos_A, (0, self.input_shape[-2] - pos_A.shape[-2], 0, self.input_shape[-1] - pos_A.shape[-1])) + pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:]) + + neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) + if self.input_shape[-2] != neg_A.shape[-2] and self.input_shape[-1] != neg_A.shape[-1]: + neg_A = F.pad(neg_A, (0, self.input_shape[-2] - neg_A.shape[-2], 0, self.input_shape[-1] - neg_A.shape[-1])) + neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:]) + + next_A = self.jit_mutiply(pos_A, neg_A, d_pos, d_neg) + elif isinstance(last_A, Patches): + if b_pos is not None: + patch_pos = Patches( + last_A.patches.clamp(min=0), last_A.stride, last_A.padding, + last_A.shape, unstable_idx=last_A.unstable_idx, + output_shape=last_A.output_shape) + bias = bias + self.get_bias(patch_pos, b_pos) + if b_neg is not None: + patch_neg = Patches( + last_A.patches.clamp(max=0), last_A.stride, last_A.padding, + last_A.shape, unstable_idx=last_A.unstable_idx, + output_shape=last_A.output_shape) + bias = bias + self.get_bias(patch_neg, b_neg) + + # bias = bias.transpose(0,1) + shape = last_A.shape + pos_A = last_A.patches.clamp(min=0) + neg_A = last_A.patches.clamp(max=0) + + def upsample(last_patches, last_A): + if last_A.unstable_idx is None: + patches = F.interpolate( + last_patches.view(shape[0] * shape[1] * shape[2], *shape[3:]), + scale_factor=[1,]+self.kernel_size) + patches = patches.view(shape[0], shape[1], shape[2], *patches.shape[1:]) + else: + patches = F.interpolate( + last_patches, scale_factor=[1,] + self.kernel_size) + return Patches( + patches, stride=last_A.stride, padding=last_A.padding, + shape=patches.shape, unstable_idx=last_A.unstable_idx, + output_shape=last_A.output_shape) + + pos_A = upsample(pos_A, last_A) + neg_A = upsample(neg_A, last_A) + + stride = self.stride[0] * last_A.stride + if isinstance(last_A.padding, int): + padding = last_A.padding * self.stride[0] + self.padding[0] + else: + # Here we need to unfold the d_pos to match pos_A and neg_A patches + # And we compute the padding and stride of pos_A and neg_A + padding = tuple(a * self.stride[0] + self.padding[0] for a in last_A.padding) + + padding, stride, output_padding = compute_patches_stride_padding( + self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride, last_A.inserted_zeros, last_A.output_padding) + + pos_A.padding, pos_A.stride, pos_A.output_padding = padding, stride, output_padding + neg_A.padding, neg_A.stride, neg_A.output_padding = padding, stride, output_padding + + # unsqueeze for the spec dimension + d_pos = maybe_unfold_patches(d_pos.unsqueeze(0), pos_A) + d_neg = maybe_unfold_patches(d_neg.unsqueeze(0), neg_A) + + + next_A_patches = self.jit_mutiply(pos_A.patches, neg_A.patches, d_pos, d_neg) + + if start_node is not None: + self.patch_size[start_node.name] = next_A_patches.size() + + + next_A = Patches( + next_A_patches, stride, padding, next_A_patches.shape, + unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape, + inserted_zeros=last_A.inserted_zeros, output_padding=output_padding) + + return next_A, bias + + if self.padding[0] > 0: + upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + + uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, 
upper_b, lower_b) + lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b) + + return [(lA, uA)], lbias, ubias + + def bound_forward(self, dim_in, x): + lower_d, lower_b, upper_d, upper_b = self.bound_relax(x) + + def _bound_oneside(w_pos, b_pos, w_neg, b_neg, d, b): + d_pos, d_neg = d.clamp(min=0), d.clamp(max=0) + w_new = d_pos.unsqueeze(1) * w_pos + d_neg.unsqueeze(1) * w_neg + b_new = d_pos * b_pos + d_neg * b_neg + if isinstance(self.kernel_size, list) and len(self.kernel_size) == 2: + tot_kernel_size = prod(self.kernel_size) + elif isinstance(self.kernel_size, int): + tot_kernel_size = self.kernel_size ** 2 + else: + raise ValueError(f'Unsupported kernel size {self.kernel_size}') + w_pooled = (F.avg_pool2d(w_new.view(-1, *w_new.shape[2:]), + self.kernel_size, self.stride, self.padding, + ceil_mode=self.ceil_mode) * tot_kernel_size) + w_pooled = w_pooled.reshape(w_new.shape[0], -1, *w_pooled.shape[1:]) + b_pooled = F.avg_pool2d(b_new, self.kernel_size, self.stride, self.padding, + ceil_mode=self.ceil_mode) * tot_kernel_size + b + return w_pooled, b_pooled + + lw, lb = _bound_oneside(x.lw, x.lb, x.uw, x.ub, lower_d, lower_b) + uw, ub = _bound_oneside(x.uw, x.ub, x.lw, x.lb, upper_d, upper_b) + + return LinearBound(lw, lb, uw, ub) + + def bound_relax(self, x): + # Only used by forward mode + paddings = tuple(self.padding + self.padding) + self.upper, self.lower = x.upper, x.lower + + # A_shape = last_lA.shape if last_lA is not None else last_uA.shape + # batch_size, input_c, x, y + upper_d = torch.zeros_like(x.lower) + lower_d = torch.zeros_like(x.lower) + + upper_d = F.pad(upper_d, paddings) + lower_d = F.pad(lower_d, paddings) + + # batch_size, output_c, x, y + upper_b = torch.zeros((list(self.output_shape))).to(x.lower) + lower_b = torch.zeros((list(self.output_shape))).to(x.lower) + + # 1. 
find the index i where li > uj for all j, then set upper_d = lower_d = 1 + max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape) + max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode) + + values = torch.zeros_like(max_lower) + values[max_lower >= max_upper] = 1.0 + upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape) + + # FIXME shape error + if False and self.opt_stage == 'opt': + alpha = self.alpha[self._start] + + if self.init[self._start] == False: + lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) + lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride) + + alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1]) + alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach()) + self.init[self._start] = True + if self.padding[0] > 0: + lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + + alpha.data = self.project_simplex(alpha.data).clone().detach() + alpha = alpha.permute((0,1,2,3,6,7,4,5)) + alpha_shape = alpha.shape + alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1])) + lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride) + lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:]) + lower_d = lower_d.squeeze(0) + else: + lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) + if self.padding[0] > 0: + lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + + values[:] = 0.0 + max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + values[max_upper > max_lower] = max_upper_[max_upper > max_lower] + upper_b = values + + if self.padding[0] > 0: + upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + + return lower_d, lower_b, upper_d, upper_b + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., last layer input gurobi vars (3,32,32) + gvars_array = np.array(v[0]) + # pre_layer_shape (1,32,27,27) + pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape + # this layer shape (1,32,6,6) + this_layer_shape = self.output_shape + assert this_layer_shape[2] == ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0]) + + new_layer_gurobi_vars = [] + neuron_idx = 0 + pre_ubs = self.forward(self.inputs[0].upper).detach().cpu().numpy() + + for out_chan_idx in range(this_layer_shape[1]): + out_chan_vars = [] + for out_row_idx in range(this_layer_shape[2]): + out_row_vars = [] + for out_col_idx in range(this_layer_shape[3]): + a_sum = 0.0 + v = model.addVar(lb=-float('inf'), ub=float('inf'), + obj=0, vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + for ker_row_idx in range(self.kernel_size[0]): + in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx 
+ ker_row_idx + if (in_row_idx < 0) or (in_row_idx == len(gvars_array[out_chan_idx][ker_row_idx])): + # This is padding -> value of 0 + continue + for ker_col_idx in range(self.kernel_size[1]): + in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx + if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]): + # This is padding -> value of 0 + continue + var = gvars_array[out_chan_idx][in_row_idx][in_col_idx] + a = model.addVar(vtype=grb.GRB.BINARY) + a_sum += a + model.addConstr(v >= var) + model.addConstr(v <= var + (1 - a) * pre_ubs[0, out_chan_idx, out_row_idx, out_col_idx]) + model.addConstr(a_sum == 1, name=f'lay{self.name}_{neuron_idx}_eq') + out_row_vars.append(v) + out_chan_vars.append(out_row_vars) + new_layer_gurobi_vars.append(out_chan_vars) + + self.solver_vars = new_layer_gurobi_vars + model.update() + + + +class BoundGlobalAveragePool(Bound): + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + + def forward(self, x): + output = nn.AdaptiveAvgPool2d((1, 1)).forward(x) # adaptiveAveragePool with output size (1, 1) + return output + + def bound_backward(self, last_lA, last_uA, x): + H, W = self.input_shape[-2], self.input_shape[-1] + + lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None + uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None + + return [(lA, uA)], 0, 0 + + def interval_propagate(self, *v): + h_L, h_U = v[0] + h_L = F.adaptive_avg_pool2d(h_L, (1, 1)) + h_U = F.adaptive_avg_pool2d(h_U, (1, 1)) + return h_L, h_U + + +class BoundAveragePool(Bound): + def __init__(self, attr, inputs, output_index, options): + # assumptions: ceil_mode=False, count_include_pad=True + super().__init__(attr, inputs, output_index, options) + + assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2]) + assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3]) + + self.kernel_size = attr['kernel_shape'] + assert len(self.kernel_size) == 2 + self.stride = attr['strides'] + assert len(self.stride) == 2 + # FIXME (22/07/02): padding is inconsistently handled. Should use 4-tuple. + + if 'pads' not in attr: + self.padding = [0, 0] + else: + self.padding = [attr['pads'][0], attr['pads'][1]] + self.ceil_mode = False + self.count_include_pad = True + self.use_default_ibp = True + + def forward(self, x): + return F.avg_pool2d(x, self.kernel_size, self.stride, + self.padding, self.ceil_mode, self.count_include_pad) + + def bound_backward(self, last_lA, last_uA, x): + def _bound_oneside(last_A): + if last_A is None: + return None, 0 + if isinstance(last_A, torch.Tensor): + shape = last_A.size() + # propagate A to the next layer, with batch concatenated together + next_A = F.interpolate(last_A.view(shape[0] * shape[1], *shape[2:]), + scale_factor=self.kernel_size) / (prod(self.kernel_size)) + next_A = F.pad(next_A, (0, self.input_shape[-2] - next_A.shape[-2], 0, self.input_shape[-1] - next_A.shape[-1])) + next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:]) + elif isinstance(last_A, Patches): + patches = last_A.patches + shape = patches.size() + # When the number of inserted zeros can cancel out the stride, we use a shortcut that can reduce computation. 
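+ # Background for the branches below: average pooling is a depthwise convolution
+ # with a constant kernel of value 1/(kH*kW), so propagating A backward through it
+ # is a transposed convolution with that same constant kernel. Informal intuition
+ # for the cancellation condition (a sketch, assuming a square kernel): when the
+ # previous layer inserted inserted_zeros zeros between patch entries and
+ # inserted_zeros + 1 == kernel_size, those zeros land exactly on the positions a
+ # stride-k transposed convolution would skip, so a stride-1 transposed
+ # convolution on the plain kernel computes the same result.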
+ simplify_patch = (last_A.inserted_zeros + 1 == self.kernel_size[0]) and (self.kernel_size[0] == self.kernel_size[1])
+ padding, stride, output_padding = compute_patches_stride_padding(
+ self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride,
+ inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding, simplify=not simplify_patch)
+ inserted_zeros = last_A.inserted_zeros
+ if last_A.inserted_zeros == 0:
+ # No inserted zeros, can be handled using interpolate.
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.interpolate(patches.view(shape[0] * shape[1], shape[2] * shape[3], *shape[4:]), scale_factor=[1,] + self.kernel_size)
+ # The dimensions of patch_H and patch_W have changed.
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.interpolate(patches, scale_factor=[1,] + self.kernel_size)
+ # Divide by the averaging factor.
+ up_sampled_patches = up_sampled_patches / prod(self.kernel_size)
+ elif simplify_patch:
+ padding = tuple(p // s - o for p, s, o in zip(padding, stride, output_padding))
+ output_padding = (0, 0, 0, 0)
+ stride = 1 # Stride and inserted zeros cancel out. No need to insert zeros or add output_padding.
+ inserted_zeros = 0
+ value = 1. / prod(self.kernel_size)
+ # In the case where the stride and inserted zeros cancel out, we do not need to insert zeros into the weight.
+ weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device)
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=1, groups=self.input_shape[1])
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=1, groups=self.input_shape[1])
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ else:
+ # With inserted zeros, must be handled by treating pooling as general convolution.
+ value = 1. / prod(self.kernel_size)
+ weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device)
+ weight = insert_zeros(weight, last_A.inserted_zeros)
+ if last_A.unstable_idx is None:
+ # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=self.kernel_size, groups=self.input_shape[1])
+ else:
+ # shape is: [spec, batch, in_c, patch_H, patch_W]
+ up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=self.kernel_size, groups=self.input_shape[1])
+ up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))
+ next_A = last_A.create_similar(up_sampled_patches, stride=stride, padding=padding, output_padding=output_padding, inserted_zeros=inserted_zeros)
+ else:
+ raise ValueError(f'last_A has unexpected type {type(last_A)}')
+ return next_A, 0.
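+ # A minimal hand-checkable sketch (illustrative, square kernel assumed) that the
+ # transposed-convolution path above matches the interpolate-based path when
+ # inserted_zeros == 0:
+ #   p = torch.randn(2, 3, 4, 4)                    # [spec*batch, in_c, pH, pW]
+ #   k = 2
+ #   w = torch.full((3, 1, k, k), 1. / k ** 2)
+ #   a = F.conv_transpose2d(p, w, stride=k, groups=3)
+ #   b = F.interpolate(p, scale_factor=k) / k ** 2
+ #   assert torch.allclose(a, b)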
+ + lA, lbias = _bound_oneside(last_lA) + uA, ubias = _bound_oneside(last_uA) + return [(lA, uA)], lbias, ubias + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., last layer input gurobi vars (3,32,32) + gvars_array = np.array(v[0]) + # pre_layer_shape (1,32,27,27) + pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape + # this layer shape (1,32,6,6) + this_layer_shape = self.output_shape + assert this_layer_shape[2] == ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0]) + + value = 1.0/(self.kernel_size[0] * self.kernel_size[1]) + new_layer_gurobi_vars = [] + neuron_idx = 0 + for out_chan_idx in range(this_layer_shape[1]): + out_chan_vars = [] + for out_row_idx in range(this_layer_shape[2]): + out_row_vars = [] + for out_col_idx in range(this_layer_shape[3]): + # print(self.bias.shape, out_chan_idx, out_lbs.size(1)) + lin_expr = 0.0 + for ker_row_idx in range(self.kernel_size[0]): + in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx + ker_row_idx + if (in_row_idx < 0) or (in_row_idx == len(gvars_array[out_chan_idx][ker_row_idx])): + # This is padding -> value of 0 + continue + for ker_col_idx in range(self.kernel_size[1]): + in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx + if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]): + # This is padding -> value of 0 + continue + coeff = value + lin_expr += coeff * gvars_array[out_chan_idx][in_row_idx][in_col_idx] + v = model.addVar(lb=-float('inf'), ub=float('inf'), + obj=0, vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(lin_expr == v, name=f'lay{self.name}_{neuron_idx}_eq') + neuron_idx += 1 + + out_row_vars.append(v) + out_chan_vars.append(out_row_vars) + new_layer_gurobi_vars.append(out_chan_vars) + + self.solver_vars = new_layer_gurobi_vars + model.update() diff --git a/auto_LiRPA/operators/reduce.py b/auto_LiRPA/operators/reduce.py index 246cc98..51da427 100644 --- a/auto_LiRPA/operators/reduce.py +++ b/auto_LiRPA/operators/reduce.py @@ -3,8 +3,8 @@ class BoundReduceMax(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.axis = attr['axes'] # for torch.max, `dim` must be an int if isinstance(self.axis, list): @@ -18,10 +18,8 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio in Softmax of Transformers.""" self.fixed_max_index = options.get('fixed_reducemax_index', False) - @Bound.save_io_shape def forward(self, x): - if self.axis < 0: - self.axis += len(self.input_shape) + self.axis = self.make_axis_non_negative(self.axis) assert self.axis > 0 res = torch.max(x, dim=self.axis, keepdim=self.keepdim) self.indices = res.indices @@ -53,20 +51,19 @@ def _bound_oneside(last_A): class BoundReduceMean(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.axis = attr['axes'] self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x): return torch.mean(x, 
dim=self.axis, keepdim=self.keepdim) def bound_backward(self, last_lA, last_uA, x): for i in range(len(self.axis)): if self.axis[i] < 0: - self.axis[i] = len(self.input_shape) + self.axis[i] + self.axis[i] = self.make_axis_non_negative(self.axis[i]) assert self.axis[i] > 0 def _bound_oneside(last_A): @@ -89,9 +86,7 @@ def _bound_oneside(last_A): def bound_forward(self, dim_in, x): assert (self.keepdim) assert (len(self.axis) == 1) - axis = self.axis[0] - if axis < 0: - axis = len(self.input_shape) + axis + axis = self.make_axis_non_negative(self.axis[0]) assert (axis > 0) size = self.input_shape[axis] lw = x.lw.sum(dim=axis + 1, keepdim=True) / size @@ -107,13 +102,12 @@ def infer_batch_dim(self, batch_size, *x): return x[0] class BoundReduceSum(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.axis = attr['axes'] if 'axes' in attr else None self.keepdim = bool(attr['keepdims']) self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x): if self.axis is not None: return torch.sum(x, dim=self.axis, keepdim=self.keepdim) @@ -143,14 +137,13 @@ def _bound_oneside(last_A): return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0 def bound_forward(self, dim_in, x): - assert self.keepdim assert len(self.axis) == 1 - axis = self.axis[0] - if axis < 0: - axis = len(self.input_shape) + axis - assert (axis > 0) - lw, lb = x.lw.sum(dim=axis + 1, keepdim=True), x.lb.sum(dim=axis, keepdim=True) - uw, ub = x.uw.sum(dim=axis + 1, keepdim=True), x.ub.sum(dim=axis, keepdim=True) + axis = self.make_axis_non_negative(self.axis[0]) + assert axis > 0 + lw = x.lw.sum(dim=axis + 1, keepdim=self.keepdim) + lb = x.lb.sum(dim=axis, keepdim=self.keepdim) + uw = x.uw.sum(dim=axis + 1, keepdim=self.keepdim) + ub = x.ub.sum(dim=axis, keepdim=self.keepdim) return LinearBound(lw, lb, uw, ub) def infer_batch_dim(self, batch_size, *x): diff --git a/auto_LiRPA/operators/rnn.py b/auto_LiRPA/operators/rnn.py index 4d59f27..7330bf6 100644 --- a/auto_LiRPA/operators/rnn.py +++ b/auto_LiRPA/operators/rnn.py @@ -3,12 +3,11 @@ class BoundRNN(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.complex = True self.output_index = output_index - @Bound.save_io_shape def forward(self, x, weight_input, weight_recurrent, bias, sequence_length, initial_h): assert (torch.sum(torch.abs(initial_h)) == 0) @@ -17,7 +16,7 @@ def forward(self, x, weight_input, weight_recurrent, bias, sequence_length, init class BoundRNNImpl(nn.Module): def __init__(self, input_size, hidden_size, - weight_input, weight_recurrent, bias, output_index, options, device): + weight_input, weight_recurrent, bias, output_index, options): super().__init__() self.input_size = input_size @@ -38,7 +37,7 @@ def __init__(self, input_size, hidden_size, def forward(self, x): length = x.shape[0] outputs = [] - hidden = torch.zeros(x.shape[1], self.hidden_size, device=self.device) + hidden = torch.zeros(x.shape[1], self.hidden_size).to(x) for i in range(length): hidden = self.cell(x[i, :], hidden) outputs.append(hidden.unsqueeze(0)) @@ -52,7 
+51,7 @@ def forward(self, x): self.model = BoundRNNImpl( self.input_size, self.hidden_size, weight_input, weight_recurrent, bias, - self.output_index, self.device) + self.output_index) self.input = (x,) return self.model(self.input) \ No newline at end of file diff --git a/auto_LiRPA/operators/shape.py b/auto_LiRPA/operators/shape.py index 1f34b29..7b9c4b2 100644 --- a/auto_LiRPA/operators/shape.py +++ b/auto_LiRPA/operators/shape.py @@ -1,26 +1,60 @@ """ Shape operators """ from .base import * +from ..patches import Patches, patches_to_matrix +from .linear import BoundLinear class BoundReshape(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + # It can be set to `view`, so that `view` instead of `reshape` will be used. + self.option = options.get('reshape', 'reshape') - @Bound.save_io_shape def forward(self, x, shape): shape = list(shape) for i in range(len(shape)): if shape[i] == -1: shape[i] = prod(x.shape) // int(prod(shape[:i]) * prod(shape[(i + 1):])) self.shape = shape - return x.reshape(shape) + if self.option == 'view': + return x.contiguous().view(shape) + else: + return x.reshape(shape) def bound_backward(self, last_lA, last_uA, x, shape): def _bound_oneside(A): if A is None: return None - # A shape is (spec, batch, *node_shape) - return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:]) - + if type(A) == Patches: + if type(self.inputs[0]) == BoundLinear: + # Save the shape and it will be converted to matrix in Linear layer. + return A.create_similar(input_shape=self.output_shape) + if A.unstable_idx is None: + patches = A.patches + # non-sparse: [batch, out_dim, out_c, out_H, out_W, out_dim, in_c, H, W] + # [batch, out_dim*out_c, out_H, out_W, out_dim*in_c, H, W] + patches = patches.reshape( + patches.shape[0], + patches.shape[1]*patches.shape[2], patches.shape[3], patches.shape[4], + patches.shape[5]*patches.shape[6], patches.shape[7], patches.shape[8]) + # expected next_A shape [batch, spec, in_c, in_H , in_W]. + next_A = patches_to_matrix( + patches, [ + self.input_shape[0]*self.input_shape[1], + patches.shape[-3], + int(math.sqrt(self.input_shape[-1]//A.patches.shape[-3])), + int(math.sqrt(self.input_shape[-1]//A.patches.shape[-3]))], + A.stride, A.padding) + else: + # sparse: [spec, batch, in_c, patch_H, patch_W] (specs depends on the number of unstable neurons). + patches = A.patches + # expected next_A shape [batch, spec, input_c, in_H, in_W]. + next_A = patches_to_matrix(patches, [self.input_shape[0]*self.input_shape[1], patches.shape[-3], int(math.sqrt(self.input_shape[-1]//patches.shape[-3])), int(math.sqrt(self.input_shape[-1]//patches.shape[-3]))], A.stride, A.padding, output_shape=A.output_shape, unstable_idx=A.unstable_idx) + # Reshape it to [batch, spec, *input_shape] (input_shape is the shape before Reshape operation). 
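+ # The sqrt above assumes the pre-reshape tensor is flat with square spatial
+ # dimensions after unflattening: given a flattened per-example size of
+ # in_c * H * W with H == W, the spatial side is recovered as sqrt(numel / in_c).
+ # For example (values assumed for illustration): input_shape[-1] = 1024 with
+ # patches.shape[-3] = 16 channels gives sqrt(1024 // 16) = 8.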
+ next_A = next_A.reshape(A.shape[1], -1, *self.input_shape[1:]) + return next_A.transpose(0,1) + else: + return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:]) + #FIXME check reshape or view return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0 def bound_forward(self, dim_in, x, shape): @@ -32,15 +66,9 @@ def bound_forward(self, dim_in, x, shape): return LinearBound(lw, lb, uw, ub) def interval_propagate(self, *v): - if Interval.use_relative_bounds(*v): - return Interval( - None, None, - v[0].nominal.reshape(self.shape), - v[0].lower_offset.reshape(self.shape), - v[0].upper_offset.reshape(self.shape), - ptb=v[0].ptb - ) - return Interval.make_interval(v[0][0].reshape(*v[1][0]), v[0][1].reshape(*v[1][0]), v[0]) + return Interval.make_interval( + self.forward(v[0][0], v[1][0]), + self.forward(v[0][1], v[1][0]), v[0]) def infer_batch_dim(self, batch_size, *x): if x[0] == -1: @@ -51,29 +79,48 @@ def infer_batch_dim(self, batch_size, *x): self.input_shape, self.shape, x[0] )) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if isinstance(v[0], Tensor): + self.solver_vars = self.forward(*v) + return + gvar_array = np.array(v[0]) + gvar_array = gvar_array.reshape(v[1].detach().cpu().numpy())[0] + self.solver_vars = gvar_array.tolist() + class BoundUnsqueeze(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.axes = attr['axes'] assert (len(self.axes) == 1) self.axes = self.axes[0] self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x): - if self.axes < 0: - self.axes = len(self.input_shape) + self.axes + 1 - return x.unsqueeze(self.axes) + return x.unsqueeze(self.axes) def bound_backward(self, last_lA, last_uA, x): + self.axes = self.make_axis_non_negative(self.axes, 'output') if self.axes == 0: - return last_lA, 0, last_uA, 0 + # TODO: unsqueeze on batch dimension can be problematic. 
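+ # For axes == 0 we pass A through unchanged; for axes > 0 below, the backward
+ # rule is the adjoint of unsqueeze: A has shape [spec, batch, *output_shape], so
+ # the size-1 dimension inserted at `axes` of the output sits at `axes + 1` of A
+ # and is squeezed away. E.g. (illustrative): for output_shape (B, 1, C), A is
+ # [spec, B, 1, C] and A.squeeze(self.axes + 1) with axes == 1 restores
+ # [spec, B, C]. For Patches, the same dimension is addressed from the end of the
+ # patches tensor, hence the (typically negative) index self.axes - 5.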
+ return [(last_lA, last_uA)], 0, 0 else: - return [(last_lA.squeeze(self.axes + 1) if last_lA is not None else None, - last_uA.squeeze(self.axes + 1) if last_uA is not None else None)], 0, 0 + if type(last_lA) == Patches: + lA = Patches(last_lA.patches.squeeze(self.axes - 5), last_lA.stride, last_lA.padding, last_lA.shape, last_lA.identity, last_lA.unstable_idx, last_lA.output_shape) + elif last_lA is not None: + lA = last_lA.squeeze(self.axes+1) + else: + lA = None + if type(last_uA) == Patches: + uA = Patches(last_uA.patches.squeeze(self.axes - 5), last_uA.stride, last_uA.padding, last_uA.shape, last_uA.identity, last_uA.unstable_idx, last_uA.output_shape) + elif last_uA is not None: + uA = last_uA.squeeze(self.axes+1) + else: + uA = None + return [(lA, uA)], 0, 0 def bound_forward(self, dim_in, x): + self.axes = self.make_axis_non_negative(self.axes, 'output') if len(self.input_shape) == 0: lw, lb = x.lw.unsqueeze(1), x.lb.unsqueeze(0) uw, ub = x.uw.unsqueeze(1), x.ub.unsqueeze(0) @@ -88,19 +135,20 @@ def infer_batch_dim(self, batch_size, *x): elif self.axes > x[0]: return x[0] raise NotImplementedError - + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(v[0]) class BoundSqueeze(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.axes = attr['axes'] assert (len(self.axes) == 1) self.axes = self.axes[0] self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x): - return x.squeeze(self.axes) + return x.squeeze(self.axes) def bound_backward(self, last_lA, last_uA, x): assert (self.axes != 0) @@ -119,12 +167,11 @@ def infer_batch_dim(self, batch_size, *x): class BoundFlatten(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.use_default_ibp = True self.axis = attr['axis'] - @Bound.save_io_shape def forward(self, x): return torch.flatten(x, self.axis) @@ -133,22 +180,42 @@ def _bound_oneside(A): if A is None: return None return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:]) - return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0 + def bound_dynamic_forward(self, x, max_dim=None, offset=0): + w = torch.flatten(x.lw, self.axis + 1) + b = torch.flatten(x.lb, self.axis) + x_L = torch.flatten(x.x_L, self.axis) + x_U = torch.flatten(x.x_U, self.axis) + return LinearBound(w, b, w, b, x_L=x_L, x_U=x_U, tot_dim=x.tot_dim) + + def bound_forward(self, dim_in, x): + self.axis = self.make_axis_non_negative(self.axis) + assert self.axis > 0 + return LinearBound( + torch.flatten(x.lw, self.axis + 1), + torch.flatten(x.lb, self.axis), + torch.flatten(x.uw, self.axis + 1), + torch.flatten(x.ub, self.axis), + ) + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + # e.g., v[0] input shape (16, 8, 8) => output shape (1024,) + self.solver_vars = np.array(v[0]).reshape(-1).tolist() + model.update() + + class BoundConcat(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - 
super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.IBP_rets = None
- @Bound.save_io_shape
def forward(self, *x): # x is a list of tensors
- x = [(item if isinstance(item, torch.Tensor) else torch.tensor(item)) for item in x]
+ x = [(item if isinstance(item, Tensor) else torch.tensor(item)) for item in x]
self.input_size = [item.shape[self.axis] for item in x]
- if self.axis < 0:
- self.axis = x[0].ndim + self.axis
+ self.axis = self.make_axis_non_negative(self.axis)
return torch.cat(x, dim=int(self.axis))

def interval_propagate(self, *v):
@@ -169,15 +236,6 @@ def interval_propagate(self, *v):
all_inf = all(map(lambda x: x is None or x == np.inf, norms))
all_2 = all(map(lambda x: x is None or x == 2, norms))

- if Interval.use_relative_bounds(*v):
- assert all_inf # Only LINF supported for now
- return Interval(
- None, None,
- self.forward(*[_v.nominal for _v in v]),
- self.forward(*[_v.lower_offset for _v in v]),
- self.forward(*[_v.upper_offset for _v in v]),
- )
-
h_L = [_v[0] for _v in v]
h_U = [_v[1] for _v in v]
if all_inf:
@@ -195,14 +253,22 @@ def interval_propagate(self, *v):
raise RuntimeError("BoundConcat does not support inputs with norm {}".format(norms))

def bound_backward(self, last_lA, last_uA, *x):
- if self.axis < 0:
- self.axis = len(self.output_shape) + self.axis
- assert (self.axis > 0)
+ self.axis = self.make_axis_non_negative(self.axis, 'output')
+ assert self.axis > 0

def _bound_oneside(last_A):
if last_A is None:
return None
- return torch.split(last_A, self.input_size, dim=self.axis + 1)
+ if isinstance(last_A, torch.Tensor):
+ return torch.split(last_A, self.input_size, dim=self.axis + 1)
+ elif isinstance(last_A, Patches):
+ assert len(self.input_shape) == 4 and self.axis == 1, "Splitting the channel dimension is supported; other dimensions are unimplemented."
+ # Patches shape can be [out_c, batch, out_h, out_w, in_c, patch_h, patch_w]
+ # Or [spec, batch, in_c, patch_h, patch_w] (sparse)
+ new_patches = torch.split(last_A.patches, self.input_size, dim=-3) # splitting along the in_c dimension is straightforward.
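+ # Example of the channel split (values assumed for illustration): two concat
+ # inputs with 16 and 8 channels give patches with in_c == 24 at dim -3, and each
+ # input receives its own slice of A:
+ #   p = torch.randn(10, 1, 4, 4, 24, 3, 3)
+ #   a, b = torch.split(p, [16, 8], dim=-3)   # in_c 24 -> 16 and 8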
+ return [last_A.create_similar(p) for p in new_patches] + else: + raise RuntimeError(f'Unsupported type for last_A: {type(last_A)}') uA = _bound_oneside(last_uA) lA = _bound_oneside(last_lA) @@ -213,8 +279,7 @@ def _bound_oneside(last_A): return [(lA[i], uA[i]) for i in range(len(lA))], 0, 0 def bound_forward(self, dim_in, *x): - if self.axis < 0: - self.axis = x[0].lb.ndim + self.axis + self.axis = self.make_axis_non_negative(self.axis) assert (self.axis == 0 and not self.from_input or self.from_input) lw = torch.cat([item.lw for item in x], dim=self.axis + 1) lb = torch.cat([item.lb for item in x], dim=self.axis) @@ -222,19 +287,25 @@ def bound_forward(self, dim_in, *x): ub = torch.cat([item.ub for item in x], dim=self.axis) return LinearBound(lw, lb, uw, ub) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(*v) + def infer_batch_dim(self, batch_size, *x): assert np.min(x) == np.max(x) assert x[0] != self.axis return x[0] + +BoundConcatFromSequence = BoundConcat + class BoundShape(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) - self.use_default_ibp = True + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) + self.use_default_ibp = True @staticmethod def shape(x): - return x.shape if isinstance(x, torch.Tensor) else torch.tensor(x).shape + return x.shape if isinstance(x, Tensor) else torch.tensor(x).shape def forward(self, x): self.from_input = False @@ -247,59 +318,88 @@ def infer_batch_dim(self, batch_size, *x): return -1 def interval_propagate(self, *v): - if Interval.use_relative_bounds(*v): - shape = self.forward(v[0].nominal) - if not isinstance(shape, torch.Tensor): - shape = torch.tensor(shape, device=self.device) - return Interval( - None, None, - shape, torch.zeros_like(shape), torch.zeros_like(shape) - ) - return super().interval_propagate(*v) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if not isinstance(v[0], Tensor): + # e.g., v[0] input shape (8, 7, 7) => output its shape (1, 8, 7, 7) + gvars_array = np.array(v[0]) + self.solver_vars = torch.tensor(np.expand_dims(gvars_array, axis=0).shape).long() + else: + self.solver_vars = torch.tensor(self.forward(v[0])).long() + class BoundGather(Bound): - def __init__(self, input_name, name, ori_name, attr, x, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, x, output_index, options, device) + def __init__(self, attr, x, output_index, options): + super().__init__(attr, x, output_index, options) self.axis = attr['axis'] if 'axis' in attr else 0 - self.nonlinear = False # input shape required - @Bound.save_io_shape def forward(self, x, indices): self.indices = indices + if self.axis == -1: + self.axis = len(x.shape) - 1 x = x.to(self.indices.device) # BoundShape.shape() will return value on cpu only if indices.ndim == 0: - # `index_select` requires `indices` to be a 1-D tensor return torch.index_select(x, dim=self.axis, index=indices).squeeze(self.axis) elif self.axis == 0: return torch.index_select(x, dim=self.axis, index=indices.reshape(-1)) \ .reshape(*indices.shape, x.shape[-1]) - raise ValueError( - 'Unsupported shapes in Gather: data {}, indices {}, axis {}'.format(x.shape, indices.shape, self.axis)) + elif self.indices.ndim == 1: + # `index_select` requires `indices` to be a 1-D 
tensor + return torch.index_select(x, dim=self.axis, index=indices) + + raise ValueError('Unsupported shapes in Gather: data {}, indices {}, axis {}'.format(x.shape, indices.shape, self.axis)) def bound_backward(self, last_lA, last_uA, x, indices): assert self.from_input - assert self.indices.ndim == 0 # TODO - def _bound_oneside(A): - if A is None: - return None - assert (self.indices.ndim == 0) - - A = A.unsqueeze(self.axis + 1) - idx = int(self.indices) + def _expand_A_with_zeros(A, axis, idx, max_axis_size): + # Need to recreate A with three parts: before the gathered element, gathered element, and after gathered element. tensors = [] if idx > 0: shape_pre = list(A.shape) - shape_pre[self.axis + 1] *= idx - tensors.append(torch.zeros(shape_pre, device=self.device)) + shape_pre[axis] *= idx + # Create the same shape as A, except for the dimension to be gathered. + tensors.append(torch.zeros(shape_pre, device=A.device)) + # The gathered element itself, in the middle. tensors.append(A) - if self.input_shape[self.axis] - idx - 1 > 0: + if max_axis_size - idx - 1 > 0: shape_next = list(A.shape) - shape_next[self.axis + 1] *= self.input_shape[self.axis] - idx - 1 - tensors.append(torch.zeros(shape_next, device=self.device)) - return torch.cat(tensors, dim=self.axis + 1) + shape_next[axis] *= max_axis_size - idx - 1 + # Create the rest part of A. + tensors.append(torch.zeros(shape_next, device=A.device)) + # Concatenate all three parts together. + return torch.cat(tensors, dim=axis) + + def _bound_oneside(A): + if A is None: + return None + + if isinstance(A, torch.Tensor): + if self.indices.ndim == 0: + A = A.unsqueeze(self.axis + 1) + idx = int(self.indices) + return _expand_A_with_zeros(A, self.axis + 1, idx, self.input_shape[self.axis]) + else: + shape = list(A.shape) + final_A = torch.zeros(*shape[:self.axis + 1], self.input_shape[self.axis], *shape[self.axis + 2:], device=A.device) + idx = self.indices.view([*[1]*(self.axis+1), -1, *[1]*len(shape[self.axis + 2:])]) + idx = idx.repeat([*A.shape[:self.axis+1], 1, *A.shape[self.axis+2:]]) + final_A.scatter_(dim=self.axis+1, index=idx, src=A) + return final_A + elif isinstance(A, Patches): + if self.indices.ndim == 0: + idx = int(self.indices) + assert len(self.input_shape) == 4 and self.axis == 1, "Gather is only supported on the channel dimension for Patches mode." + # For gather in the channel dimension, we only need to deal with the in_c dimension (-3) in patches. + patches = A.patches + # -3 is the in_c dimension. 
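+ # The expansion is the adjoint of index_select: zeros fill every non-gathered
+ # channel. Tiny dense analogue (illustrative values): with 4 input channels and
+ # idx == 2, a single-channel A becomes [0, 0, A, 0] along in_c:
+ #   A = torch.ones(2, 1, 1, 3, 3)
+ #   zs = lambda c: torch.zeros(2, 1, c, 3, 3)
+ #   full = torch.cat([zs(2), A, zs(1)], dim=-3)   # 4 channels total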
+ new_patches = _expand_A_with_zeros(patches, axis=-3, idx=idx, max_axis_size=self.input_shape[self.axis]) + return A.create_similar(new_patches) + else: + raise NotImplementedError + else: + raise ValueError(f'Unknown last_A type {type(A)}') return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0 @@ -321,15 +421,6 @@ def bound_forward(self, dim_in, x, indices): def interval_propagate(self, *v): assert not self.is_input_perturbed(1) - - if Interval.use_relative_bounds(*v): - return Interval( - None, None, - self.forward(v[0].nominal, v[1].nominal), - self.forward(v[0].lower_offset, v[1].nominal), - self.forward(v[0].upper_offset, v[1].nominal) - ) - return self.forward(v[0][0], v[1][0]), self.forward(v[0][1], v[1][0]) def infer_batch_dim(self, batch_size, *x): @@ -339,27 +430,22 @@ def infer_batch_dim(self, batch_size, *x): else: return x[1] - def infer_batch_dim(self, batch_size, *x): - if x[0] != -1: - assert self.axis != x[0] - return x[0] - else: - return x[1] + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(v[0], v[1]) class BoundGatherElements(Bound): - def __init__(self, input_name, name, ori_name, attr, input, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, input, output_index, options, device) + def __init__(self, attr, input, output_index, options): + super().__init__(attr, input, output_index, options) self.axis = attr['axis'] - @Bound.save_io_shape def forward(self, x, index): self.index = index return torch.gather(x, dim=self.axis, index=index) def bound_backward(self, last_lA, last_uA, x, index): assert self.from_input - + dim = self._get_dim() def _bound_oneside(last_A): @@ -377,15 +463,6 @@ def _bound_oneside(last_A): def interval_propagate(self, *v): assert not self.is_input_perturbed(1) - - if Interval.use_relative_bounds(*v): - return Interval( - None, None, - self.forward(v[0].nominal, v[1].nominal), - self.forward(v[0].lower_offset, v[1].nominal), - self.forward(v[0].upper_offset, v[1].nominal) - ) - return self.forward(v[0][0], v[1][0]), \ self.forward(v[0][1], v[1][1]) @@ -401,7 +478,7 @@ def bound_forward(self, dim_in, x, index): def infer_batch_dim(self, batch_size, *x): assert self.axis != x[0] return x[0] - + def _get_dim(self): dim = self.axis if dim < 0: @@ -409,16 +486,15 @@ def _get_dim(self): return dim class BoundTranspose(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + def __init__(self, attr, inputs, output_index, options): + super().__init__(attr, inputs, output_index, options) self.perm = attr['perm'] self.perm_inv_inc_one = [-1] * (len(self.perm) + 1) self.perm_inv_inc_one[0] = 0 for i in range(len(self.perm)): self.perm_inv_inc_one[self.perm[i] + 1] = i + 1 - self.use_default_ibp = True + self.use_default_ibp = True - @Bound.save_io_shape def forward(self, x): return x.permute(*self.perm) @@ -441,6 +517,9 @@ def bound_forward(self, dim_in, x): return LinearBound(lw, lb, uw, ub) + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(*v) + def infer_batch_dim(self, batch_size, *x): if x[0] == -1: return -1 @@ -449,21 +528,14 @@ def infer_batch_dim(self, batch_size, *x): class BoundSlice(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): - super().__init__(input_name, 
name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.start = attr["starts"][0] if "starts" in attr else None
self.end = attr["ends"][0] if "ends" in attr else None
self.axes = attr["axes"][0] if "axes" in attr else None
self.use_default_ibp = False

- # Older Pytorch version only passes steps as input.
- @Bound.save_io_shape
- def forward(self, x, start=None, end=None, axes=None, steps=1):
- start = self.start if start is None else start
- end = self.end if end is None else end
- axes = self.axes if axes is None else axes
- assert (steps == 1 or steps == -1) and axes == int(axes) and start == int(start) and end == int(end)
- shape = x.shape if isinstance(x, torch.Tensor) else [len(x)]
+ def _fixup_params(self, shape, start, end, axes, steps):
if start < 0:
start += shape[axes]
if end < 0:
@@ -471,19 +543,32 @@ def forward(self, x, start=None, end=None, axes=None, steps=1):
end = 0 # only possible when step == -1
else:
end += shape[axes]
- if steps == -1:
+ if steps == -1:
start, end = end, start + 1 # TODO: more tests for negative step sizes.
end = min(end, shape[axes])
+ return start, end
+
+ # Older Pytorch version only passes steps as input.
+ def forward(self, x, start=None, end=None, axes=None, steps=1):
+ start = self.start if start is None else start
+ end = self.end if end is None else end
+ axes = self.axes if axes is None else axes
+ assert (steps == 1 or steps == -1) and axes == int(axes) and start == int(start) and end == int(end)
+ shape = x.shape if isinstance(x, Tensor) else [len(x)]
+ start, end = self._fixup_params(shape, start, end, axes, steps)
final = torch.narrow(x, dim=int(axes), start=int(start), length=int(end - start))
if steps == -1:
final = torch.flip(final, dims=tuple(axes))
return final
-
- def interval_propagate(self, *v):
+
+ def interval_propagate(self, *v):
lb = tuple(map(lambda x:x[0],v))
ub = tuple(map(lambda x:x[1],v))
return Interval.make_interval(self.forward(*lb), self.forward(*ub))

+ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
+ self.solver_vars = self.forward(*v)
+
def infer_batch_dim(self, batch_size, *x):
if x[0] == -1:
return -1
@@ -491,12 +576,71 @@ def infer_batch_dim(self, batch_size, *x):
assert self.axes != x[0]
return x[0]

+ def bound_backward(self, last_lA, last_uA, *x):
+ def _bound_oneside(A, start, end, axes, steps):
+ if A is None:
+ return None
+ if isinstance(A, torch.Tensor):
+ # Reuse the batch and spec dimension of A, and replace other shapes with input.
+ A_shape = A.shape[:2] + self.input_shape[1:]
+ new_A = torch.zeros(size=A_shape, device=A.device, requires_grad=A.requires_grad)
+ # Fill part of the new_A based on start, end, axes and steps.
+ # Skip the spec dimension at the front (axes + 1).
+ dim = axes if axes < 0 else axes + 1
+ indices = torch.arange(start, end, device=A.device)
+ new_A = torch.index_copy(new_A, dim=dim, index=indices, source=A)
+ elif isinstance(A, Patches):
+ assert A.unstable_idx is None
+ assert len(self.input_shape) == 4 and axes == 1, "Slice is only supported on the channel dimension."
+ patches = A.patches
+ # patches shape is [out_c, batch, out_h, out_w, in_c, patch_h, patch_w].
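+ # Below, the slice is undone by embedding A into zeros over the full channel
+ # range via index_copy, the adjoint of narrow. Tiny dense analogue (values
+ # assumed for illustration), for a slice of channels 1:3 out of 4:
+ #   A = torch.randn(5, 2, 3, 3)                  # sliced A, 2 channels
+ #   full = torch.zeros(5, 4, 3, 3)
+ #   full = torch.index_copy(full, 1, torch.arange(1, 3), A)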
+ new_patches_shape = patches.shape[:4] + (self.input_shape[1], ) + patches.shape[-2:]
+ new_patches = torch.zeros(size=new_patches_shape, device=patches.device, requires_grad=patches.requires_grad)
+ indices = torch.arange(start, end, device=patches.device)
+ new_patches = torch.index_copy(new_patches, dim=-3, index=indices, source=patches)
+ # Only the in_c dimension is changed.
+ new_A = A.create_similar(new_patches)
+ else:
+ raise ValueError(f'Unsupported A type {type(A)}')
+ return new_A
+
+ start, end, axes = x[1].value.item(), x[2].value.item(), x[3].value.item()
+ steps = x[4].value.item() if len(x) == 5 else 1 # If step is not specified, it is 1.
+ # Other step sizes are untested; do not enable them for now.
+ assert steps == 1 and axes == int(axes) and start == int(start) and end == int(end)
+ start, end = self._fixup_params(self.input_shape, start, end, axes, steps)
+ # Find the original shape of A.
+ lA = _bound_oneside(last_lA, start, end, axes, steps)
+ uA = _bound_oneside(last_uA, start, end, axes, steps)
+ return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], 0, 0
+
+ def bound_forward(self, dim_in, *inputs):
+ assert len(inputs) == 5 or len(inputs) == 4
+ start = inputs[1].lb.item()
+ end = inputs[2].lb.item()
+ axis = self.make_axis_non_negative(inputs[3].lb.item())
+ assert axis > 0, "Slicing along the batch dimension is not supported yet"
+ steps = inputs[4].lb.item() if len(inputs) == 5 else 1 # If step is not specified, it is 1.
+ assert steps in [1, -1]
+ x = inputs[0]
+ shape = x.lb.shape
+ start, end = self._fixup_params(shape, start, end, axis, steps)
+ lw = torch.narrow(x.lw, dim=axis+1, start=start, length=end - start)
+ uw = torch.narrow(x.uw, dim=axis+1, start=start, length=end - start)
+ lb = torch.narrow(x.lb, dim=axis, start=start, length=end - start)
+ ub = torch.narrow(x.ub, dim=axis, start=start, length=end - start)
+ if steps == -1:
+ lw = torch.flip(lw, dims=tuple(axis+1))
+ uw = torch.flip(uw, dims=tuple(axis+1))
+ lb = torch.flip(lb, dims=tuple(axis))
+ ub = torch.flip(ub, dims=tuple(axis))
+ return LinearBound(lw, lb, uw, ub)
+

class BoundExpand(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)

- @Bound.save_io_shape
def forward(self, x, y):
y = y.clone()
assert y.ndim == 1
@@ -518,18 +662,19 @@ def infer_batch_dim(self, batch_size, *x):

class BoundSplit(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.split = attr['split']
self.use_default_ibp = True

- @Bound.save_io_shape
def forward(self, x):
+ if self.axis == -1:
+ self.axis = len(x.shape) - 1
return torch.split(x, self.split, dim=self.axis)[self.output_index]

def bound_backward(self, last_lA, last_uA, x):
- assert (self.axis > 0)
+ assert self.axis > 0
pre = sum(self.split[:self.output_index])
suc = sum(self.split[(self.output_index + 1):])
diff --git a/auto_LiRPA/operators/softmax.py b/auto_LiRPA/operators/softmax.py
index 21ad6b0..c8fdd1c 100644
--- a/auto_LiRPA/operators/softmax.py
+++ b/auto_LiRPA/operators/softmax.py
@@ -7,7 +7,6 @@ def
__init__(self, axis):
self.axis = axis
assert self.axis == int(self.axis)

- @Bound.save_io_shape
def forward(self, x):
max_x = torch.max(x, dim=self.axis).values
x = torch.exp(x - max_x.unsqueeze(self.axis))
@@ -16,8 +15,8 @@ def forward(self, x):

# The `option != 'complex'` case is not used in the auto_LiRPA main paper.
class BoundSoftmax(Bound):
- def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device):
- super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device)
+ def __init__(self, attr, inputs, output_index, options):
+ super().__init__(attr, inputs, output_index, options)
self.axis = attr['axis']
self.option = options.get('softmax', 'complex')
if self.option == 'complex':
@@ -25,7 +24,6 @@ def __init__(self, input_name, name, ori_name, attr, inputs, output_index, optio
else:
self.max_input = 30

- @Bound.save_io_shape
def forward(self, x):
assert self.axis == int(self.axis)
if self.option == 'complex':
diff --git a/auto_LiRPA/operators/solver_utils.py b/auto_LiRPA/operators/solver_utils.py
new file mode 100644
index 0000000..79a02b2
--- /dev/null
+++ b/auto_LiRPA/operators/solver_utils.py
@@ -0,0 +1,11 @@
+class DummyGurobipyClass:
+ """A dummy class that raises an error when gurobipy is not installed."""
+ def __getattr__(self, attr):
+ def _f(*args, **kwargs):
+ raise RuntimeError(f"method {attr} is not available because the gurobipy module is not installed.")
+ return _f
+
+try:
+ import gurobipy as grb
+except ModuleNotFoundError:
+ grb = DummyGurobipyClass()
\ No newline at end of file
diff --git a/auto_LiRPA/optimized_bounds.py b/auto_LiRPA/optimized_bounds.py
new file mode 100644
index 0000000..7ec1ffd
--- /dev/null
+++ b/auto_LiRPA/optimized_bounds.py
@@ -0,0 +1,1046 @@
+import time
+import os
+import warnings
+from collections import OrderedDict
+from contextlib import ExitStack
+
+import torch
+from torch import optim
+from .cuda_utils import double2float
+
+
+def _set_alpha(optimizable_activations, parameters, alphas, lr):
+ """
+ Set up best_alphas, alphas and the parameters list.
+ """
+ for node in optimizable_activations:
+ alphas.extend(list(node.alpha.values()))
+ node.opt_start()
+ # Alpha has shape (2, output_shape, batch_dim, node_shape)
+ parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2})
+ # best_alpha is a dictionary of dictionaries. Each key is the alpha variable
+ # for one ReLU layer, and each value is a dictionary that contains all ReLU
+ # layers after that layer as keys.
+ best_alphas = OrderedDict()
+ for m in optimizable_activations:
+ best_alphas[m.name] = {}
+ for alpha_m in m.alpha:
+ best_alphas[m.name][alpha_m] = m.alpha[alpha_m].detach().clone()
+ # We will directly replace the dictionary for each ReLU layer after
+ # optimization, so the saved alpha might not have requires_grad=True.
+ m.alpha[alpha_m].requires_grad_()
+
+ return best_alphas
+
+
+def _set_beta(
+ self, relus, optimizable_activations, single_node_split,
+ enable_opt_interm_bounds, betas, opt_coeffs, parameters,
+ lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter, dense_coeffs_mask):
+ """
+ Set betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases
+ and best_biases.
+ """ + coeffs = best_coeffs = biases = best_biases = None + if len(relus) != len(optimizable_activations): + warnings.warn( + 'Only relu split is supported so far, this model contains other ' + 'optimizable activations that may not apply split.') + + if single_node_split: + for node in relus: + if enable_opt_interm_bounds and node.sparse_beta is not None: + for key in node.sparse_beta.keys(): + if node.sparse_beta[key] is not None: + betas.append(node.sparse_beta[key]) + else: + if node.sparse_beta is not None: + betas.append(node.sparse_beta) + else: + betas = self.beta_params + self.single_beta_params + if opt_coeffs: + coeffs = [dense_coeffs['dense'] + for dense_coeffs in self.split_dense_coeffs_params + ] + self.coeffs_params + dense_coeffs_mask = [dense_coeffs['mask'] + for dense_coeffs in self.split_dense_coeffs_params] + parameters.append({'params': coeffs, 'lr': lr_coeffs}) + best_coeffs = [coeff.detach().clone() for coeff in coeffs] + if opt_bias: + biases = self.bias_params + parameters.append({'params': biases, 'lr': lr_coeffs}) + best_biases = [bias.detach().clone() for bias in biases] + + # Beta has shape (batch, max_splits_per_layer) + parameters.append({'params': betas, 'lr': lr_beta, 'batch_dim': 0}) + + if self.cut_used: + # also need to optimize cut betas + parameters.append({'params': self.cut_beta_params, + 'lr': lr_cut_beta, 'batch_dim': 0}) + betas = betas + self.cut_beta_params + + if enable_opt_interm_bounds and betas: + best_betas = OrderedDict() + for m in optimizable_activations: + best_betas[m.name] = {} + for beta_m, beta in m.sparse_beta.items(): + best_betas[m.name][beta_m] = beta.detach().clone() + if self.cut_used: + best_betas['cut'] = [] + for general_betas in self.cut_beta_params: + best_betas['cut'].append(general_betas.detach().clone()) + else: + best_betas = [b.detach().clone() for b in betas] + + if self.cut_used and getattr(cutter, 'opt', False): + parameters.append(cutter.get_parameters()) + + return ( + betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases, + best_biases) + + +def _save_ret_first_time(bounds, fill_value, x, best_ret): + """ + Save results at the first iteration to best_ret + """ + if bounds is not None: + best_bounds = torch.full_like( + bounds, fill_value=fill_value, device=x[0].device, dtype=x[0].dtype) + else: + best_bounds = None + + if bounds is not None: + best_ret.append(bounds.detach().clone()) + else: + best_ret.append(None) + + return best_bounds + + +@torch.no_grad() +def _get_preserve_mask( + decision_thresh, ret_l, preserve_mask, multi_spec_keep_func): + """ + Get preserve mask by decision_thresh to filter out the satisfied bounds. + """ + if (isinstance(decision_thresh, torch.Tensor) + and decision_thresh.numel() > 1): + if decision_thresh.shape[-1] == 1: + now_preserve_mask = ( + ret_l <= decision_thresh[preserve_mask] + ).view(-1).nonzero().view(-1) + else: + now_preserve_mask = multi_spec_keep_func( + ret_l <= decision_thresh[preserve_mask]).nonzero().view(-1) + else: + if decision_thresh.shape[-1] == 1: + now_preserve_mask = ( + ret_l <= decision_thresh).view(-1).nonzero().view(-1) + else: + now_preserve_mask = multi_spec_keep_func( + ret_l <= decision_thresh).nonzero().view(-1) + + return now_preserve_mask + + +def _recover_bounds_to_full_batch( + ret, decision_thresh, epsilon_over_decision_thresh, original_size, + preserve_mask, loss_reduction_func): + """ + Recover lower and upper bounds to full batch size so that later we can + directly update using the full batch size of l and u. 
+ """ + if ret is not None: + if (isinstance(decision_thresh, torch.Tensor) + and decision_thresh.numel() > 1): + full_ret = (decision_thresh.clone().to(ret.device).type(ret.dtype) + + epsilon_over_decision_thresh) + else: + num_decision_thresh = decision_thresh + if isinstance(num_decision_thresh, torch.Tensor): + num_decision_thresh = num_decision_thresh.item() + full_ret = torch.full( + (original_size,) + tuple(ret.shape[1:]), + fill_value=num_decision_thresh + epsilon_over_decision_thresh, + device=ret.device, dtype=ret.dtype) + full_ret[preserve_mask] = ret + if full_ret.shape[1] > 1: + full_reduced_ret = loss_reduction_func(full_ret) + else: + full_reduced_ret = full_ret + else: + full_ret = full_reduced_ret = None + + return full_ret, full_reduced_ret + + +@torch.no_grad() +def _prune_bounds_by_mask( + now_preserve_mask, decision_thresh, ret_l, ret_u, ret, preserve_mask, + epsilon_over_decision_thresh, original_size, loss_reduction_func, + beta, intermediate_beta_enabled, + fix_intermediate_layer_bounds, intermediate_layer_bounds, + aux_reference_bounds, partial_intermediate_layer_bounds, + pre_prune_size): + """ + Prune bounds by given now_preserve_mask. + """ + full_ret_l, full_l = _recover_bounds_to_full_batch( + ret_l, decision_thresh, epsilon_over_decision_thresh, + original_size, preserve_mask, loss_reduction_func) + + full_ret_u, full_u = _recover_bounds_to_full_batch( + ret_u, decision_thresh, epsilon_over_decision_thresh, + original_size, preserve_mask, loss_reduction_func) + + full_ret = (full_ret_l, full_ret_u) + ret[2:] + + if beta and intermediate_beta_enabled: + # prune the partial_intermediate_layer_bounds + interval_to_prune = partial_intermediate_layer_bounds + elif fix_intermediate_layer_bounds: + interval_to_prune = intermediate_layer_bounds + else: + interval_to_prune = None + if interval_to_prune is not None: + for k, v in interval_to_prune.items(): + interm_interval_l, interm_interval_r = v[0], v[1] + if interm_interval_l.shape[0] == pre_prune_size: + # the first dim is batch size and matches preserve mask + interm_interval_l = interm_interval_l[now_preserve_mask] + if interm_interval_r.shape[0] == pre_prune_size: + # the first dim is batch size and matches preserve mask + interm_interval_r = interm_interval_r[now_preserve_mask] + interval_to_prune[k] = [interm_interval_l, interm_interval_r] + + if aux_reference_bounds is not None: + for k in aux_reference_bounds: + aux_ref_l, aux_ref_r = aux_reference_bounds[k] + if aux_ref_l.shape[0] == pre_prune_size: + # the first dim is batch size and matches the preserve mask + aux_ref_l = aux_ref_l[now_preserve_mask] + if aux_ref_r.shape[0] == pre_prune_size: + # the first dim is batch size and matches the preserve mask + aux_ref_r = aux_ref_r[now_preserve_mask] + aux_reference_bounds[k] = [aux_ref_l, aux_ref_r] + + # update the global mask here for possible next iteration + preserve_mask_next = preserve_mask[now_preserve_mask] + + return full_l, full_ret_l, full_u, full_ret_u, full_ret, preserve_mask_next + + +@torch.no_grad() +def _prune_x(x, now_preserve_mask): + """ + Prune x by given now_preserve_mask. 
+ """ + x = list(x) + pre_prune_size = x[0].shape[0] + x[0].data = x[0][now_preserve_mask].data + if hasattr(x[0], 'ptb'): + if x[0].ptb.x_L is not None: + x[0].ptb.x_L = x[0].ptb.x_L[now_preserve_mask] + if x[0].ptb.x_U is not None: + x[0].ptb.x_U = x[0].ptb.x_U[now_preserve_mask] + x = tuple(x) + + return x, pre_prune_size + + +def _to_float64(self, C, x, aux_reference_bounds, intermediate_layer_bounds): + """ + Transfer variables to float64 only in the last iteration to help alleviate + floating point error. + """ + self.to(torch.float64) + C = C.to(torch.float64) + x = self._to(x, torch.float64) + # best_intermediate_bounds is linked to aux_reference_bounds! + # we only need call .to() for one of them + self._to(aux_reference_bounds, torch.float64, inplace=True) + intermediate_layer_bounds = self._to( + intermediate_layer_bounds, torch.float64) + + return C, x, intermediate_layer_bounds + + +def _to_default_dtype( + self, x, total_loss, full_ret, ret, best_intermediate_bounds, return_A): + """ + Switch back to default precision from float64 typically to adapt to + afterwards operations. + """ + total_loss = total_loss.to(torch.get_default_dtype()) + self.to(torch.get_default_dtype()) + x[0].to(torch.get_default_dtype()) + full_ret = list(full_ret) + if isinstance(ret[0], torch.Tensor): + # round down lower bound + full_ret[0] = double2float(full_ret[0], 'down') + if isinstance(ret[1], torch.Tensor): + # round up upper bound + full_ret[1] = double2float(full_ret[1], 'up') + for _k, _v in best_intermediate_bounds.items(): + _v[0] = double2float(_v[0], 'down') + _v[1] = double2float(_v[1], 'up') + best_intermediate_bounds[_k] = _v + if return_A: + full_ret[2] = self._to(full_ret[2], torch.get_default_dtype()) + + return total_loss, x, full_ret + + +def _update_best_ret( + full_ret_bound, best_ret_bound, full_ret, best_ret, need_update, idx): + """Update best_ret_bound and best_ret by comparing with new results.""" + assert idx in [0, 1], ( + '0 means updating lower bound, 1 means updating upper bound') + if idx == 0: + idx_mask = (full_ret_bound > best_ret_bound).any(dim=1).view(-1) + else: + idx_mask = (full_ret_bound < best_ret_bound).any(dim=1).view(-1) + + improved_idx = None + if idx_mask.any(): + need_update = True + # we only pick up the results improved in a batch + improved_idx = idx_mask.nonzero(as_tuple=True)[0] + # total_loss = total_loss.to(best_ret_l) + if idx == 0: + best_ret_bound[improved_idx] = torch.maximum( + full_ret_bound[improved_idx], best_ret_bound[improved_idx]) + if full_ret[idx] is not None: + best_ret[idx][improved_idx] = torch.maximum( + full_ret[idx][improved_idx], best_ret[idx][improved_idx]) + else: + best_ret_bound[improved_idx] = torch.minimum( + full_ret_bound[improved_idx], best_ret_bound[improved_idx]) + if full_ret[idx] is not None: + best_ret[idx][improved_idx] = torch.minimum( + full_ret[idx][improved_idx], best_ret[idx][improved_idx]) + + return best_ret_bound, best_ret, idx_mask, improved_idx, need_update + + +def _update_optimizable_activations( + optimizable_activations, pruning_in_iteration, + intermediate_layer_bounds, fix_intermediate_layer_bounds, + best_intermediate_bounds, idx, local_idx, alpha, best_alphas): + """ + Update bounds and alpha of optimizable_activations. + """ + for node in optimizable_activations: + reference_idx = local_idx if pruning_in_iteration else idx + # Update best intermediate layer bounds only when they are optimized. + # If they are already fixed in intermediate_layer_bounds, then do + # nothing. 
+ if (intermediate_layer_bounds is None + or node.inputs[0].name not in intermediate_layer_bounds + or not fix_intermediate_layer_bounds): + best_intermediate_bounds[node.name][0][idx] = torch.max( + best_intermediate_bounds[node.name][0][idx], + node.inputs[0].lower[reference_idx]) + best_intermediate_bounds[node.name][1][idx] = torch.min( + best_intermediate_bounds[node.name][1][idx], + node.inputs[0].upper[reference_idx]) + + if alpha: + # Each alpha has shape (2, output_shape, batch, *shape) for ReLU. + # For other activation function this can be different. + for alpha_m in node.alpha: + if node.alpha_batch_dim == 2: + best_alphas[node.name][alpha_m][:, :, + idx] = node.alpha[alpha_m][:, :, idx] + elif node.alpha_batch_dim == 3: + best_alphas[node.name][alpha_m][:, :, :, + idx] = node.alpha[alpha_m][:, :, :, idx] + else: + raise ValueError( + f'alpha_batch_dim={node.alpha_batch_dim} must be set ' + 'to 2 or 3 in BoundOptimizableActivation') + + +def _update_best_beta( + self, enable_opt_interm_bounds, betas, optimizable_activations, + best_betas, idx): + """ + Update best beta by given idx. + """ + if enable_opt_interm_bounds and betas: + for node in optimizable_activations: + for key in node.sparse_beta.keys(): + best_betas[node.name][key] = ( + node.sparse_beta[key].detach().clone()) + if self.cut_used: + for gbidx, general_betas in enumerate(self.cut_beta_params): + best_betas['cut'][gbidx] = general_betas.detach().clone() + else: + if self.cut_used: + regular_beta_length = len(betas) - len(self.cut_beta_params) + for beta_idx in range(regular_beta_length): + # regular beta crown betas + best_betas[beta_idx][idx] = betas[beta_idx][idx] + for cut_beta_idx in range(len(self.cut_beta_params)): + # general cut beta crown general_betas + best_betas[regular_beta_length + cut_beta_idx][:, :, idx, + :] = betas[regular_beta_length + cut_beta_idx][:, :, idx, :] + else: + for beta_idx in range(len(betas)): + # regular beta crown betas + best_betas[beta_idx][idx] = betas[beta_idx][idx] + + +def get_optimized_bounds( + self, x=None, aux=None, C=None, IBP=False, forward=False, + method='backward', bound_lower=True, bound_upper=False, + reuse_ibp=False, return_A=False, average_A=False, final_node_name=None, + intermediate_layer_bounds=None, reference_bounds=None, + aux_reference_bounds=None, needed_A_dict=None, cutter=None, + decision_thresh=None, epsilon_over_decision_thresh=1e-4): + """ + Optimize CROWN lower/upper bounds by alpha and/or beta. 
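+ + A minimal usage sketch (illustrative only: the option values below are + assumptions, not recommendations, and this function is normally reached + via compute_bounds with an optimizing method rather than called directly): + + model = BoundedModule(net, dummy_input, bound_opts={ + 'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) + lb, ub = model.compute_bounds(x=(bounded_x,), method='CROWN-Optimized')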
+ """ + + opts = self.bound_opts['optimize_bound_args'] + iteration = opts['iteration'] + beta = opts['enable_beta_crown'] + alpha = opts['enable_alpha_crown'] + opt_coeffs = opts['opt_coeffs'] + opt_bias = opts['opt_bias'] + opt_choice = opts['optimizer'] + single_node_split = opts['single_node_split'] + assert single_node_split is True + keep_best = opts['keep_best'] + fix_intermediate_layer_bounds = opts['fix_intermediate_layer_bounds'] + init_alpha = opts['init_alpha'] + lr_alpha = opts['lr_alpha'] + lr_beta = opts['lr_beta'] + lr_cut_beta = opts['lr_cut_beta'] + lr_intermediate_beta = opts['lr_intermediate_beta'] + lr_decay = opts['lr_decay'] + lr_coeffs = opts['lr_coeffs'] + loss_reduction_func = opts['loss_reduction_func'] + stop_criterion_func = opts['stop_criterion_func'] + use_float64_in_last_iteration = opts['use_float64_in_last_iteration'] + early_stop_patience = opts['early_stop_patience'] + intermediate_beta_enabled = opts['intermediate_beta'] + start_save_best = opts['start_save_best'] + multi_spec_keep_func = opts['multi_spec_keep_func'] + enable_opt_interm_bounds = self.bound_opts.get( + 'enable_opt_interm_bounds', False) + sparse_intermediate_bounds = self.bound_opts.get( + 'sparse_intermediate_bounds', False) + verbosity = self.bound_opts['verbosity'] + + assert bound_lower != bound_upper, ( + 'we can only optimize lower OR upper bound at one time') + assert alpha or beta, ( + 'nothing to optimize, use compute bound instead!') + + if C is not None: + self.final_shape = C.size()[:2] + self.bound_opts.update({'final_shape': self.final_shape}) + if init_alpha: + # TODO: this should set up aux_reference_bounds. + self.init_slope(x, share_slopes=opts['use_shared_alpha'], + method=method, c=C, final_node_name=final_node_name) + + # Optimizable activations that are actually used and perturbed + optimizable_activations = [ + n for n in self.optimizable_activations if n.used and n.perturbed] + # Relu node that are actually used + relus = [n for n in self.relus if n.used and n.perturbed] + + alphas, betas, parameters = [], [], [] + dense_coeffs_mask = [] + partial_intermediate_layer_bounds = None + + if alpha: + best_alphas = _set_alpha( + optimizable_activations, parameters, alphas, lr_alpha) + + if beta: + ret_set_beta = _set_beta( + self, relus, optimizable_activations, single_node_split, + enable_opt_interm_bounds, betas, opt_coeffs, parameters, + lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter, + dense_coeffs_mask) + betas, best_betas, coeffs = ret_set_beta[:3] + dense_coeffs_mask, best_coeffs, biases, best_biases = ret_set_beta[3:] + + start = time.time() + + if (decision_thresh is not None + and isinstance(decision_thresh, torch.Tensor)): + if decision_thresh.dim() == 1: + # add the spec dim to be aligned with compute_bounds return + decision_thresh = decision_thresh.unsqueeze(-1) + + + if opt_choice == 'adam-autolr': + opt = AdamElementLR(parameters) + elif opt_choice == 'adam': + opt = optim.Adam(parameters) + elif opt_choice == 'sgd': + opt = optim.SGD(parameters, momentum=0.9) + else: + raise NotImplementedError(opt_choice) + + # Create a weight vector to scale learning rate. + loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device) + scheduler = optim.lr_scheduler.ExponentialLR(opt, lr_decay) + + if verbosity > 0 and intermediate_beta_enabled: + self.print_optimized_beta(relus, intermediate_beta_enabled=True) + + # best_intermediate_bounds is linked to aux_reference_bounds! 
+ best_intermediate_bounds = {} + if (sparse_intermediate_bounds and aux_reference_bounds is None + and reference_bounds is not None): + aux_reference_bounds = {} + for name, (lb, ub) in reference_bounds.items(): + aux_reference_bounds[name] = [ + lb.detach().clone(), ub.detach().clone()] + if aux_reference_bounds is None: + aux_reference_bounds = {} + + with torch.no_grad(): + pruning_in_iteration = False + # for computing the positive domain ratio + original_size = x[0].shape[0] + preserve_mask = None + + # record the overhead due to extra operations from pruning-in-iteration + pruning_time = 0. + + need_grad = True + patience = 0 + for i in range(iteration): + if cutter: + # cuts may be optimized by cutter + self.cut_module = cutter.cut_module + + intermediate_constr = None + + if not fix_intermediate_layer_bounds: + # If we still optimize all intermediate neurons, we can use + # intermediate_layer_bounds as reference bounds. + reference_bounds = intermediate_layer_bounds + + if i == iteration - 1: + # No grad update needed for the last iteration + need_grad = False + + if (self.device == 'cuda' + and torch.get_default_dtype() == torch.float32 + and use_float64_in_last_iteration): + C, x, intermediate_layer_bounds = _to_float64( + self, C, x, aux_reference_bounds, intermediate_layer_bounds) + + # we will use last update preserve mask in caller functions to recover + # lA, l, u, etc to full batch size + self.last_update_preserve_mask = preserve_mask + with torch.no_grad() if not need_grad else ExitStack(): + # ret is lb, ub or lb, ub, A_dict (if return_A is set to true) + + # argument for intermediate_layer_bounds + # If we set neuron bounds individually, or if we are optimizing + # intermediate layer bounds using beta, we do not set + # intermediate_layer_bounds. + # When intermediate betas are used, we must set + # intermediate_layer_bounds to None because we want to recompute + # all intermediate layer bounds. + if beta and intermediate_beta_enabled: + arg_ilb = partial_intermediate_layer_bounds + elif fix_intermediate_layer_bounds: + arg_ilb = intermediate_layer_bounds + else: + arg_ilb = None + + # argument for aux_reference_bounds + if sparse_intermediate_bounds: + arg_arb = aux_reference_bounds + else: + arg_arb = None + + ret = self.compute_bounds( + x, aux, C, method=method, IBP=IBP, forward=forward, + bound_lower=bound_lower, bound_upper=bound_upper, + reuse_ibp=reuse_ibp, return_A=return_A, + final_node_name=final_node_name, average_A=average_A, + intermediate_layer_bounds=arg_ilb, + # This is the currently tightest interval, which will be used to + # pass split constraints when intermediate betas are used. + reference_bounds=reference_bounds, + # This is the interval used for checking for unstable neurons. + aux_reference_bounds=arg_arb, + # These are intermediate layer beta variables and their + # corresponding A matrices and biases. 
+ intermediate_constr=intermediate_constr, + needed_A_dict=needed_A_dict, + update_mask=preserve_mask) + + ret_l, ret_u = ret[0], ret[1] + + if (self.cut_used and i % cutter.log_interval == 0 + and len(self.cut_beta_params) > 0): + # betas[-1]: (2(0 lower, 1 upper), spec, batch, num_constrs) + if ret_l is not None: + print( + i, 'lb beta sum:', + f'{self.cut_beta_params[-1][0].sum() / ret_l.size(0)},', + f'worst {ret_l.min()}') + if ret_u is not None: + print( + i, 'ub beta sum:', + f'{self.cut_beta_params[-1][1].sum() / ret_u.size(0)},', + f'worst {ret_u.max()}') + + if i == 0: + # save results at the first iteration + best_ret = [] + best_ret_l = _save_ret_first_time( + ret[0], float('-inf'), x, best_ret) + best_ret_u = _save_ret_first_time( + ret[1], float('inf'), x, best_ret) + + for node in optimizable_activations: + new_intermediate = [ + node.inputs[0].lower.detach().clone(), + node.inputs[0].upper.detach().clone()] + best_intermediate_bounds[node.name] = new_intermediate + if sparse_intermediate_bounds: + # Always using the best bounds so far as the reference + # bounds. + aux_reference_bounds[node.inputs[0].name] = new_intermediate + + l = ret_l + # Reduction over the spec dimension. + if ret_l is not None and ret_l.shape[1] != 1: + l = loss_reduction_func(ret_l) + u = ret_u + if ret_u is not None and ret_u.shape[1] != 1: + u = loss_reduction_func(ret_u) + + # full_l, full_ret_l and full_u, full_ret_u are used to update the best + full_ret_l, full_ret_u = ret_l, ret_u + full_l = l + full_ret = ret + + # Positive domains may already have been filtered out, so the number of + # positive domains is computed as all domains minus the negative ones. + if decision_thresh is not None: + if (isinstance(decision_thresh, torch.Tensor) + and decision_thresh.numel() > 1 + and preserve_mask is not None): + if decision_thresh.shape[-1] == 1: + # single spec with pruned domains + negative_domain = ( + ret_l.view(-1) + <= decision_thresh[preserve_mask].view(-1)).sum() + else: + # multiple spec with pruned domains + negative_domain = multi_spec_keep_func( + ret_l <= decision_thresh[preserve_mask]).sum() + else: + if ret_l.shape[-1] == 1: + # single spec + negative_domain = ( + ret_l.view(-1) <= decision_thresh.view(-1)).sum() + else: + # multiple spec + negative_domain = multi_spec_keep_func( + ret_l <= decision_thresh).sum() + positive_domain_num = original_size - negative_domain + else: + positive_domain_num = -1 + positive_domain_ratio = float( + positive_domain_num) / float(original_size) + # threshold is 10% by default + next_iter_pruning_in_iteration = ( + opts['pruning_in_iteration'] and decision_thresh is not None + and positive_domain_ratio > opts['pruning_in_iteration_threshold']) + + if pruning_in_iteration: + stime = time.time() + if return_A: + raise Exception( + 'Pruning in iteration optimization does not support ' + 'return A yet. 
' + 'Please fix or discard this optimization by setting ' + '--disable_pruning_in_iteration ' + 'or bab: pruning_in_iteration: false') + now_preserve_mask = _get_preserve_mask( + decision_thresh, ret_l, preserve_mask, multi_spec_keep_func) + # prune C + if C is not None and C.shape[0] == x[0].shape[0]: + C = C[now_preserve_mask] # means C is also batch specific + # prune x + x, pre_prune_size = _prune_x(x, now_preserve_mask) + # prune bounds + ret_prune = _prune_bounds_by_mask( + now_preserve_mask, decision_thresh, ret_l, ret_u, ret, + preserve_mask, epsilon_over_decision_thresh, original_size, + loss_reduction_func, beta, intermediate_beta_enabled, + fix_intermediate_layer_bounds, + intermediate_layer_bounds, aux_reference_bounds, + partial_intermediate_layer_bounds, pre_prune_size) + full_l, full_ret_l = ret_prune[:2] + # ret_prune[2] is full_u which is unused + full_ret_u, full_ret, preserve_mask_next = ret_prune[3:] + pruning_time += time.time() - stime + + loss_ = l if bound_lower else -u + stop_criterion = stop_criterion_func( + full_ret_l) if bound_lower else stop_criterion_func(-full_ret_u) + if (type(stop_criterion) != bool + and stop_criterion.numel() > 1 and pruning_in_iteration): + stop_criterion = stop_criterion[preserve_mask] + total_loss = -1 * loss_ + if type(stop_criterion) == bool: + loss = total_loss.sum() * (not stop_criterion) + else: + loss = (total_loss * stop_criterion.logical_not()).sum() + + + stop_criterion_final = isinstance( + stop_criterion, torch.Tensor) and stop_criterion.all() + + if i == iteration - 1: + best_ret = list(best_ret) + if best_ret[0] is not None: + best_ret[0] = best_ret[0].to(torch.get_default_dtype()) + if best_ret[1] is not None: + best_ret[1] = best_ret[1].to(torch.get_default_dtype()) + + if (i == iteration - 1 and self.device == 'cuda' + and torch.get_default_dtype() == torch.float32 + and use_float64_in_last_iteration): + total_loss, x, full_ret = _to_default_dtype( + self, x, total_loss, full_ret, ret, best_intermediate_bounds, + return_A) + + with torch.no_grad(): + # for lb and ub, we update them in every iteration since updating + # them is cheap + need_update = False + if keep_best: + if best_ret_u is not None: + ret_upd = _update_best_ret( + full_ret_u, best_ret_u, full_ret, best_ret, need_update, + idx=1) + best_ret_u, best_ret, idx_mask, idx, need_update = ret_upd + if best_ret_l is not None: + ret_upd = _update_best_ret( + full_ret_l, best_ret_l, full_ret, best_ret, need_update, + idx=0) + best_ret_l, best_ret, idx_mask, idx, need_update = ret_upd + else: + # Not saving the best, just keep the last iteration. + if full_ret[0] is not None: + best_ret[0] = full_ret[0] + if full_ret[1] is not None: + best_ret[1] = full_ret[1] + if return_A: + # FIXME: A should also be updated by idx. + best_ret = [best_ret[0], best_ret[1], full_ret[2]] + + # Save variables if this is the best iteration. 
+ # To save computational cost, we only check keep_best at the first + # iteration (in case of divergence), during the later iterations + # (controlled by start_save_best), or right before an early stop + # triggered by stop_criterion or early_stop_patience. + if (i < 1 or i > int(iteration * start_save_best) + or stop_criterion_final or patience == early_stop_patience): + if need_update: + patience = 0 # bounds improved, reset patience + local_idx = None + # for the update purpose, we condition the idx to update only + # on the preserved domains + if pruning_in_iteration: + # local sparse index of preserved samples where + # idx = true + local_idx = idx_mask[preserve_mask].nonzero().view(-1) + # idx is global sparse index of preserved samples where + # idx = true + new_idx = torch.zeros_like( + idx_mask, dtype=torch.bool, device=idx.device) + new_idx[preserve_mask] = idx_mask[preserve_mask] + idx = new_idx.nonzero().view(-1) + + _update_optimizable_activations( + optimizable_activations, pruning_in_iteration, + intermediate_layer_bounds, + fix_intermediate_layer_bounds, + best_intermediate_bounds, idx, local_idx, alpha, + best_alphas) + + if beta and single_node_split: + _update_best_beta( + self, enable_opt_interm_bounds, betas, + optimizable_activations, best_betas, idx) + + else: + patience += 1 + + if os.environ.get('AUTOLIRPA_DEBUG_OPT', False): + print(f'****** iter [{i}]', + f'loss: {loss.item()}, lr: {opt.param_groups[0]["lr"]}') + + if stop_criterion_final: + print(f'\nall verified at {i}th iter') + break + + if patience > early_stop_patience: + print( + f'Early stop at {i}th iter due to {early_stop_patience}' + ' iterations without improvement!') + break + + current_lr = [param_group['lr'] for param_group in opt.param_groups] + + opt.zero_grad(set_to_none=True) + + if verbosity > 2: + print( + f'*** iter [{i}]\n', f'loss: {loss.item()}', + total_loss.squeeze().detach().cpu().numpy(), 'lr: ', + current_lr) + if beta: + self.print_optimized_beta(relus, intermediate_beta_enabled) + if opt_coeffs: + for co in coeffs: + print(f'coeff sum: {co.abs().sum():.5g}') + if beta and i == 0 and verbosity > 2: + breakpoint() + + if i != iteration - 1: + # we do not need to update parameters in the last step since the + # best result has already been obtained + loss.backward() + # All intermediate variables are not needed at this point. + self._clear_and_set_new(None) + if opt_choice == 'adam-autolr': + opt.step(lr_scale=[loss_weight, loss_weight]) + else: + opt.step() + + if beta: + # Clipping to >=0.
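+ # Projected gradient step: beta must stay nonnegative, so entries + # pushed below zero by the optimizer are reset to zero here.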
+ for b in betas: + b.data = (b >= 0) * b.data + for dmi in range(len(dense_coeffs_mask)): + # apply dense mask to the dense split coeffs matrix + coeffs[dmi].data = ( + dense_coeffs_mask[dmi].float() * coeffs[dmi].data) + + + if alpha: + for m in optimizable_activations: + m.clip_alpha_() + + scheduler.step() + + if pruning_in_iteration: + preserve_mask = preserve_mask_next + if not pruning_in_iteration and next_iter_pruning_in_iteration: + # init preserve_mask etc + preserve_mask = torch.arange( + 0, x[0].shape[0], device=x[0].device, dtype=torch.long) + pruning_in_iteration = True + + if pruning_in_iteration: + # overwrite pruned cells in best_ret by threshold + eps + if return_A: + fin_l, fin_u, fin_A = best_ret + else: + fin_l, fin_u = best_ret + fin_A = None + if fin_l is not None: + new_fin_l = full_ret_l + new_fin_l[preserve_mask] = fin_l[preserve_mask] + fin_l = new_fin_l + if fin_u is not None: + new_fin_u = full_ret_u + new_fin_u[preserve_mask] = fin_u[preserve_mask] + fin_u = new_fin_u + if return_A: + best_ret = (fin_l, fin_u, fin_A) + else: + best_ret = (fin_l, fin_u) + + if verbosity > 3: + breakpoint() + + if keep_best: + def update_best(dest, src): + for item_dest, item_src in zip(dest, src): + if enable_opt_interm_bounds: + for key in item_dest.keys(): + item_dest[key].data = item_src[key].data + else: + item_dest.data = item_src.data + + # Set all variables to their saved best values. + with torch.no_grad(): + for idx, node in enumerate(optimizable_activations): + if alpha: + # Assigns a new dictionary. + node.alpha = best_alphas[node.name] + # Update best intermediate layer bounds only when they are + # optimized. If they are already fixed in + # intermediate_layer_bounds, then do nothing. + best_intermediate = best_intermediate_bounds[node.name] + node.inputs[0].lower.data = best_intermediate[0].data + node.inputs[0].upper.data = best_intermediate[1].data + if beta: + if (single_node_split and hasattr(node, 'sparse_beta') + and node.sparse_beta is not None): + if enable_opt_interm_bounds: + for key in node.sparse_beta.keys(): + node.sparse_beta[key].copy_( + best_betas[node.name][key]) + else: + node.sparse_beta.copy_(best_betas[idx]) + else: + update_best(betas, best_betas) + if opt_coeffs: + update_best(coeffs, best_coeffs) + if opt_bias: + update_best(biases, best_biases) + if self.cut_used: + regular_beta_length = len(betas) - len(self.cut_beta_params) + for ii in range(len(self.cut_beta_params)): + self.cut_beta_params[ii].data = best_betas[ + regular_beta_length + ii].data + + if (intermediate_layer_bounds is not None + and not fix_intermediate_layer_bounds): + for l in self._modules.values(): + if (l.name in intermediate_layer_bounds.keys() + and hasattr(l, 'lower')): + l.lower = torch.max( + l.lower, intermediate_layer_bounds[l.name][0]) + l.upper = torch.min( + l.upper, intermediate_layer_bounds[l.name][1]) + infeasible_neurons = l.lower > l.upper + if infeasible_neurons.any(): + print( + f'Infeasibility detected in layer {l.name}.', + infeasible_neurons.sum().item(), + infeasible_neurons.nonzero()[:, 0]) + + if verbosity > 0: + if self.cut_used and beta: + print( + 'first 10 best general betas:', + best_betas[-1].view(2, -1)[0][:10], 'sum:', + best_betas[-1][0].sum().item()) + if best_ret_l is not None: + # FIXME: unify the handling of l and u. 
+ print( + 'best_l after optimization:', + best_ret_l.sum().item(), 'with beta sum per layer:', + [p.sum().item() for p in betas]) + print('alpha/beta optimization time:', time.time() - start) + + for node in optimizable_activations: + node.opt_end() + + # update pruning ratio + if (opts['pruning_in_iteration'] and decision_thresh is not None + and full_l.numel() > 0): + stime = time.time() + with torch.no_grad(): + if isinstance(decision_thresh, torch.Tensor): + if decision_thresh.shape[-1] == 1: + neg_domain_num = torch.sum( + full_ret_l.view(-1) <= decision_thresh.view(-1)).item() + else: + neg_domain_num = torch.sum(multi_spec_keep_func( + full_ret_l <= decision_thresh)).item() + else: + if full_l.shape[-1] == 1: + neg_domain_num = torch.sum( + full_ret_l.view(-1) <= decision_thresh).item() + else: + neg_domain_num = torch.sum(multi_spec_keep_func( + full_ret_l <= decision_thresh)).item() + now_pruning_ratio = (1.0 - + float(neg_domain_num) / float(full_l.shape[0])) + print('pruning_in_iteration open status:', pruning_in_iteration) + print( + 'ratio of positive domain =', full_l.shape[0] - neg_domain_num, + '/', full_l.shape[0], '=', now_pruning_ratio) + pruning_time += time.time() - stime + print('pruning-in-iteration extra time:', pruning_time) + + return best_ret + + +def init_slope( + self, x, share_slopes=False, method='backward', + c=None, bound_lower=True, bound_upper=True, final_node_name=None, + intermediate_layer_bounds=None, activation_opt_params=None, + skip_bound_compute=False): + for node in self.optimizable_activations: + # initialize the parameters + node.opt_init() + + if (not skip_bound_compute or intermediate_layer_bounds is None or + activation_opt_params is None or not all( + [relu.name in activation_opt_params for relu in self.relus])): + skipped = False + # If intermediate_layer_bounds is None, no previously computed CROWN + # intervals are available; in this case, we still need to run a full + # CROWN pass to initialize the lower/upper bounds. + with torch.no_grad(): + l, u = self.compute_bounds( + x=x, C=c, method=method, bound_lower=bound_lower, + bound_upper=bound_upper, final_node_name=final_node_name, + intermediate_layer_bounds=intermediate_layer_bounds) + else: + # we skip the bound computation, but we still need to figure out the + # "used", "perturbed", "backward_from" of each node in the graph + skipped = True + # this sets the "perturbed" property + self._set_input( + *x, intermediate_layer_bounds=intermediate_layer_bounds) + + final = self.final_node( + ) if final_node_name is None else self[final_node_name] + self._set_used_nodes(final) + + self.backward_from = {node: [final] for node in self._modules} + + final_node_name = final_node_name or self.final_name + + init_intermediate_bounds = {} + for node in self.optimizable_activations: + if not node.used or not node.perturbed: + continue + start_nodes = [] + if method in ['forward', 'forward+backward']: + start_nodes.append(('_forward', 1, None)) + if method in ['backward', 'forward+backward']: + start_nodes += self.get_alpha_crown_start_nodes( + node, c=c, share_slopes=share_slopes, + final_node_name=final_node_name) + if skipped: + node.restore_optimized_params(activation_opt_params[node.name]) + else: + node.init_opt_parameters(start_nodes) + init_intermediate_bounds[node.inputs[0].name] = ( + [node.inputs[0].lower.detach(), node.inputs[0].upper.detach()]) + + if self.bound_opts['verbosity'] >= 1: + print('Optimizable variables initialized.') + if skip_bound_compute: + return init_intermediate_bounds + else: + return l, u, init_intermediate_bounds diff --git
a/auto_LiRPA/parse_graph.py b/auto_LiRPA/parse_graph.py index a2d054f..3137b7c 100644 --- a/auto_LiRPA/parse_graph.py +++ b/auto_LiRPA/parse_graph.py @@ -1,15 +1,15 @@ import os import torch +from torch.onnx.utils import _optimize_graph +from torch.onnx.symbolic_helper import _set_opset_version from collections import OrderedDict -import re from collections import namedtuple -from torch.onnx import OperatorExportTypes -from packaging import version +import re from .bounded_tensor import BoundedTensor, BoundedParameter from .utils import logger, unpack_inputs Node = namedtuple('Node', ( - 'name', 'ori_name', 'inputs', 'attr', 'op', 'param', 'input_index', + 'name', 'ori_name', 'inputs', 'attr', 'op', 'param', 'input_index', 'bound_node', 'output_index', 'perturbation'), defaults=(None,) * 10) def get_node_name(node): @@ -17,14 +17,14 @@ def get_node_name(node): def parse_graph(graph, inputs, params): input_all = [] - input_used = [] + input_used = [] scope = {} for n in graph.inputs(): input_all.append(n.debugName()) for n in graph.nodes(): n_inputs = [get_node_name(i) for i in n.inputs()] for inp in n.inputs(): - input_used.append(inp.debugName()) + input_used.append(inp.debugName()) for out in n.outputs(): scope[get_node_name(out)] = n.scopeName() for node in graph.inputs(): @@ -39,17 +39,18 @@ def parse_graph(graph, inputs, params): def name_with_scope(node): name = get_node_name(node) return '/'.join([scope[name], name]) - + nodesOP = [] for n in graph.nodes(): attrs = {k: n[k] for k in n.attributeNames()} n_inputs = [name_with_scope(i) for i in n.inputs()] for i, out in enumerate(list(n.outputs())): + nodesOP.append(Node(**{'name': name_with_scope(out), 'op': n.kind(), 'inputs': n_inputs, 'attr': attrs, - 'output_index': i, + 'output_index': i, })) # filter out input nodes in `graph.inputs()` that are actually used @@ -64,7 +65,7 @@ def name_with_scope(node): # filter out input nodes in `inputs` that are actually used inputs_unpacked = unpack_inputs(inputs) assert len(list(graph.inputs())) == len(inputs_unpacked) + len(params) - inputs = [inputs_unpacked[i] for i in range(len(inputs_unpacked)) if used_by_index[i]] + inputs = [inputs_unpacked[i] for i in range(len(inputs_unpacked)) if used_by_index[i]] # index of the used inputs among all the inputs input_index = [i for i in range(len(inputs_unpacked)) if used_by_index[i]] # Add a name to all inputs @@ -72,7 +73,7 @@ def name_with_scope(node): # filter out params that are actually used params = [params[i] for i in range(len(params)) if used_by_index[i + len(inputs_unpacked)]] inputs_and_params = inputs + params - assert len(nodesIn) == len(inputs_and_params) + assert len(nodesIn) == len(inputs_and_params) # output nodes of the module nodesOut = [] @@ -81,7 +82,7 @@ def name_with_scope(node): nodesOut.append(name_with_scope(n)) for i, n in enumerate(nodesIn): - if (isinstance(inputs_and_params[i][1], BoundedTensor) or + if (isinstance(inputs_and_params[i][1], BoundedTensor) or isinstance(inputs_and_params[i][1], BoundedParameter)): perturbation = inputs_and_params[i][1].ptb else: @@ -92,10 +93,10 @@ def name_with_scope(node): nodesIn[i] = Node(**{'name': name_with_scope(n), 'ori_name': inputs_and_params[i][0], 'op': 'Parameter', - 'inputs': [], + 'inputs': [], 'attr': str(n.type()), 'param': inputs_and_params[i][1] if i >= len(inputs) else None, - # index among all the inputs including unused ones + # index among all the inputs including unused ones 'input_index': input_index[i] if i < len(inputs) else None, # Input nodes may have 
perturbation, if they are wrapped in BoundedTensor or BoundedParameters 'perturbation': perturbation, }) @@ -123,7 +124,7 @@ return params -"""Construct a template for the module output with `None` representing places +"""Construct a template for the module output with `None` representing places to be filled with tensor results""" def get_output_template(out): if isinstance(out, torch.Tensor): @@ -142,12 +143,10 @@ def parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=None): params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include) - # PyTorch>=1.5 is required here trace, out = torch.jit._get_trace_graph(module, inputs) - from torch.onnx.symbolic_helper import _set_opset_version _set_opset_version(12) - trace_graph = torch.onnx.utils._optimize_graph(trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, params_dict={}) - + trace_graph = _optimize_graph( + trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, params_dict={}) logger.debug('trace_graph: {}'.format(trace_graph)) if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0: @@ -158,7 +157,7 @@ parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=No if not isinstance(inputs, tuple): inputs = (inputs, ) - + nodesOP, nodesIn, nodesOut = parse_graph(trace_graph, tuple(inputs), tuple(params)) for i in range(len(nodesOP)): diff --git a/auto_LiRPA/patches.py b/auto_LiRPA/patches.py new file mode 100644 index 0000000..87a7cbe --- /dev/null +++ b/auto_LiRPA/patches.py @@ -0,0 +1,502 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + + +def insert_zeros(image, s): + """ + Insert s columns and rows of zeros between every pixel in the image. For example: + image = [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]] + s = 2 + output = [[1, 0, 0, 2, 0, 0, 3], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [4, 0, 0, 5, 0, 0, 6], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [7, 0, 0, 8, 0, 0, 9]] + """ + if s <= 0: + return image + matrix = torch.zeros(size=(image.size(0), image.size(1), image.size(2) * (s+1) - s, image.size(3) * (s+1) - s), dtype=image.dtype, device=image.device) + matrix_stride = matrix.stride() + selected_matrix = torch.as_strided(matrix, [ + # Shape of the output matrix. + matrix.size(0), # Batch size. + matrix.size(1), # Channel. + image.size(2), # H (without zeros) + image.size(3), # W (without zeros) + ], [ + # Stride of the output matrix. + matrix_stride[0], # Batch size dimension, keep using the old stride. + matrix_stride[1], # Channel dimension. + matrix_stride[2] * (s + 1), # Move s+1 rows. + s+1, # Move s+1 pixels. + ]) # Move a pixel (on the width direction). + selected_matrix[:] = image + return matrix + + +def remove_zeros(image, s): + if s <= 0: + return image + matrix_stride = image.stride() + return torch.as_strided(image, [ + # Shape of the output matrix. + *image.shape[:-2], + (image.size(-2) + s) // (s + 1), # H (without zeros) + (image.size(-1) + s) // (s + 1), # W (without zeros) + ], [ + # Stride of the output matrix. + *matrix_stride[:-2], + matrix_stride[-2] * (s + 1), # Move s+1 rows. + matrix_stride[-1] * (s + 1), # Move s+1 pixels. 
+ ]) + + +def unify_shape(shape): + """Convert shapes to 4-tuple.""" + if shape is not None: + if isinstance(shape, int): + shape = (shape, shape, shape, shape) + if len(shape) == 2: + shape = (shape[1], shape[1], shape[0], shape[0]) + assert len(shape) == 4 + return shape + + +def is_shape_used(shape, expected=0): + if isinstance(shape, int): + return shape != expected + else: + return sum(shape) != expected + + +class Patches: + """ + A special class which denotes a convolutional operator as a group of patches. + The shape of Patches.patches is [batch_size, num_of_patches, out_channel, in_channel, M, M] + M is the size of a single patch + Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N) + num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2 + Here we only consider kernels with the same H and W + """ + def __init__( + self, patches=None, stride=1, padding=0, shape=None, identity=0, + unstable_idx=None, output_shape=None, inserted_zeros=0, output_padding=0, input_shape=None): + # Shape: [batch_size, num_of_patches, out_channel, in_channel, M, M] + # M is the size of a single patch + # Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N) + # num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2 + # Here we only consider kernels with the same H and W + self.patches = patches + self.stride = stride + self.padding = padding + self.shape = shape + self.identity = identity + self.unstable_idx = unstable_idx + self.output_shape = output_shape + self.input_shape = input_shape + self.inserted_zeros = inserted_zeros + self.output_padding = output_padding + self.simplify() + + def __add__(self, other): + if isinstance(other, Patches): + # Insert images with zero to make stride the same, if necessary. + assert self.stride == other.stride + if self.unstable_idx is not None or other.unstable_idx is not None: + if self.unstable_idx is not other.unstable_idx: # Same tuple object. + raise ValueError('Please set bound option "sparse_conv_intermediate_bounds" to False to run this model.') + assert self.output_shape == other.output_shape + A1 = self.patches + A2 = other.patches + # change paddings to merge the two patches + sp = torch.tensor(unify_shape(self.padding)) + op = torch.tensor(unify_shape(other.padding)) + if (sp - op).abs().sum().item() > 0: + if (sp - op >= 0).all(): + A2 = F.pad(A2, (sp - op).tolist()) + elif (sp - op <= 0).all(): + A1 = F.pad(A1, (op - sp).tolist()) + else: + raise ValueError("Unsupported padding size") + ret = A1 + A2 + return Patches(ret, other.stride, torch.max(sp, op).tolist(), + ret.shape, unstable_idx=self.unstable_idx, output_shape=self.output_shape, + inserted_zeros=self.inserted_zeros, output_padding=self.output_padding) + else: + assert self.inserted_zeros == 0 + assert not is_shape_used(self.output_padding) + # Patches has shape (out_c, batch, out_h, out_w, in_c, h, w).
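+ # The other operand is a dense tensor, so we densify our patches via + # patches_to_matrix() below and add the two in matrix form.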
+ input_shape = other.shape[3:] + matrix = other + pieces = self.patches + if pieces.ndim == 9: + pieces = pieces.transpose(0, 1) + pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5]*pieces.shape[6], pieces.shape[7], pieces.shape[8]).transpose(0,1) + if pieces.ndim == 8: + pieces = pieces.transpose(0, 1) + pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5], pieces.shape[6], pieces.shape[7]).transpose(0,1) + A1_matrix = patches_to_matrix( + pieces, input_shape, self.stride, self.padding, + output_shape=self.output_shape, unstable_idx=self.unstable_idx) + return A1_matrix.transpose(0, 1) + matrix + + def create_similar(self, patches=None, stride=None, padding=None, identity=None, + unstable_idx=None, output_shape=None, inserted_zeros=None, output_padding=None, + input_shape=None): + """ + Create a new Patches object with new patches weights, and keep other properties the same. + """ + new_patches = self.patches if patches is None else patches + return Patches( + new_patches, + stride=self.stride if stride is None else stride, + padding=self.padding if padding is None else padding, + shape=new_patches.shape, + identity=self.identity if identity is None else identity, + unstable_idx=self.unstable_idx if unstable_idx is None else unstable_idx, + output_shape=self.output_shape if output_shape is None else output_shape, + inserted_zeros=self.inserted_zeros if inserted_zeros is None else inserted_zeros, + output_padding=self.output_padding if output_padding is None else output_padding, + input_shape=self.input_shape if input_shape is None else input_shape, + ) + + def to_matrix(self, input_shape): + assert self.inserted_zeros == 0 + assert not is_shape_used(self.output_padding) + return patches_to_matrix(self.patches, input_shape, self.stride, self.padding, self.output_shape, self.unstable_idx) + + def simplify(self): + """Merge stride and inserted_zeros; if they are the same they can cancel out.""" + stride = [self.stride, self.stride] if isinstance(self.stride, int) else self.stride + if self.inserted_zeros > 0 and self.inserted_zeros + 1 == stride[0] and stride[0] == stride[1]: + # print(f'before simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}') + full_stride = [stride[1], stride[1], stride[0], stride[0]] + # output_padding = tuple(p // s for p, s in zip(output_padding, full_stride)) + self.padding = tuple(p // s - o for p, s, o in zip(self.padding, full_stride, unify_shape(self.output_padding))) + self.patches = remove_zeros(self.patches, self.inserted_zeros) + self.stride = 1 + self.inserted_zeros = 0 + self.output_padding = 0 + # print(f'after simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}') + + def matmul(self, input, patch_abs=False, input_shape=None): + """ + Broadcast multiplication for patches and a matrix. + + Input shape: (batch_size, in_c, in_h, in_w). 
+ If in_c, in_h or in_w has dimension 1, the input will be expanded to the given input_shape to support broadcasting + + Output shape: [batch_size, unstable_size] when unstable_idx is not None, + [batch_size, out_c, out_h, out_w] when unstable_idx is None, + """ + + patches = self.patches + if patch_abs: + patches = patches.abs() + + if input_shape is not None: + # For cases that input only has fewer dimensions like (1, in_c, 1, 1) + input = input.expand(input_shape) + # Expand to (batch_size, in_c, in_h, in_w) + + # unfold the input as [batch_size, out_h, out_w, in_c, H, W] + unfold_input = inplace_unfold( + input, kernel_size=patches.shape[-2:], + padding=self.padding, stride=self.stride, + inserted_zeros=self.inserted_zeros, output_padding=self.output_padding) + if self.unstable_idx is not None: + # We need to add a out_c dimension and select from it. + unfold_input = unfold_input.unsqueeze(0).expand(self.output_shape[1], -1, -1, -1, -1, -1, -1) + # Shape: [unstable_size, batch_size, in_c, H, W]. + # Here unfold_input will match this shape. + unfold_input = unfold_input[self.unstable_idx[0], :, self.unstable_idx[1], self.unstable_idx[2]] + # shape: [batch_size, unstable_size]. + return torch.einsum('sbchw,sbchw->bs', unfold_input, patches) + else: + # shape: [batch_size, out_c, out_h, out_w]. + return torch.einsum('bijchw,sbijchw->bsij', unfold_input, patches) + + +def compute_patches_stride_padding(input_shape, patches_padding, patches_stride, op_padding, op_stride, inserted_zeros=0, output_padding=0, simplify=True): + """ + Compute stride and padding after a conv layer with patches mode. + """ + for p in (patches_padding, patches_stride, op_padding, op_stride): + assert isinstance(p, int) or (isinstance(p, (list, tuple)) and (len(p) == 2 or len(p) == 4)) + # If p is int, then same padding on all 4 sides. + # If p is 2-tuple, then it is padding p[0] on both sides of H, p[1] on both sides of W + # If p is 4-tuple, then it is padding p[2], p[3] on top and bottom sides of H, p[0] and p[1] on left and right sides of W + + # If any of the inputs are not tuple/list, we convert them to tuple. + full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [ + (p, p) if isinstance(p, int) else p for p in [patches_padding, op_padding, patches_stride, op_stride]] + full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [ + (p[1], p[1], p[0], p[0]) if len(p) == 2 else p for p in [full_patch_padding, full_op_padding, full_patch_stride, full_op_stride]] + # Compute the new padding and stride after this layer. + new_padding = tuple(pp * os + op * (inserted_zeros + 1) for pp, op, os in zip(full_patch_padding, full_op_padding, full_op_stride)) + new_stride = tuple(ps * os for ps, os in zip(full_patch_stride, full_op_stride)) + + output_padding = unify_shape(output_padding) + new_output_padding = (output_padding[0], # Left + output_padding[1] + inserted_zeros * input_shape[3] % full_op_stride[2], # Right + output_padding[2], # Top + output_padding[3] + inserted_zeros * input_shape[2] % full_op_stride[0]) # Bottom + + # Merge into a single number if all numbers are identical.
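+ # Worked example (hypothetical numbers, inserted_zeros == 0): a patch with + # padding 1 and stride 2 passing through a conv with padding 2 and stride 2 + # yields new_padding = 1 * 2 + 2 * 1 = 4 and new_stride = 2 * 2 = 4 per side.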
+ if simplify: + if new_padding.count(new_padding[0]) == len(new_padding): + new_padding = new_padding[0] + if new_stride.count(new_stride[0]) == len(new_stride): + new_stride = new_stride[0] + + return new_padding, new_stride, new_output_padding + + +def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, unstable_idx=None): + """Converting a Patches piece into a full dense matrix.""" + if type(padding) == int: + padding = (padding, padding, padding, padding) + + if pieces.ndim == 9: + # Squeeze two additional dimensions for output and input respectively + assert pieces.shape[1] == 1 and pieces.shape[5] == 1 + pieces = pieces.reshape( + pieces.shape[0], *pieces.shape[2:5], + *pieces.shape[6:] + ) + + if unstable_idx is None: + assert pieces.ndim == 7 + # Non-sparse pieces, with shape (out_c, batch, out_h, out_w, c, h, w). + output_channel, batch_size, output_x, output_y = pieces.shape[:4] + else: + batch_size = pieces.shape[1] + output_channel, output_x, output_y = output_shape[1:] + input_channel, kernel_x, kernel_y = pieces.shape[-3:] + input_x, input_y = input_shape[-2:] + + if unstable_idx is None: + # Fix all patches in a full A matrix. + A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype) + # Save its original stride. + orig_stride = A_matrix.stride() + # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution. + # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient. + matrix_strided = torch.as_strided(A_matrix, [batch_size, output_channel, output_x, output_y, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], orig_stride[2], orig_stride[3], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[4], input_y + padding[0] + padding[1], 1]) + # Now we need to fill the conv kernel parameters into the last three dimensions of matrix_strided. + first_indices = torch.arange(output_x * output_y, device=pieces.device) + second_indices = torch.div(first_indices, output_y, rounding_mode="trunc") + third_indices = torch.fmod(first_indices, output_y) + # pieces have shape (out_c, batch, out_h, out_w, c, h, w). + pieces = pieces.transpose(0, 1) # pieces has the out_c dimension at the front, need to move it to the second. + matrix_strided[:,:,second_indices,third_indices,second_indices,third_indices,:,:,:] = pieces.reshape(*pieces.shape[:2], -1, *pieces.shape[4:]) + A_matrix = A_matrix.view(batch_size, output_channel * output_x * output_y, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1]) + else: + # Fill only a selection of patches. + # Create only a partial A matrix. + unstable_size = unstable_idx[0].numel() + A_matrix = torch.zeros(batch_size, unstable_size, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype) + # Save its original stride. + orig_stride = A_matrix.stride() + # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution. + # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
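+ # The strided view exposes, for each output location, the kernel-sized + # window of the (padded) input that it touches; writing the patches into + # this view scatters each patch to its correct place in the dense A matrix.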
+ matrix_strided = torch.as_strided(A_matrix, [batch_size, unstable_size, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[2], input_y + padding[0] + padding[1], 1]) + # pieces have shape (unstable_size, batch, c, h, w). + first_indices = torch.arange(unstable_size, device=pieces.device) + matrix_strided[:,first_indices,unstable_idx[1],unstable_idx[2],:,:,:] = pieces.transpose(0, 1).to(matrix_strided) + A_matrix = A_matrix.view(batch_size, unstable_size, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1]) + + A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]] + + return A_matrix + + +def check_patch_biases(lb, ub, lower_b, upper_b): + # When we use patches mode, it is possible that we need to add two biases: + # one is from the Tensor mode and one is from the patches mode. + # We need to detect this case and reshape the bias accordingly. + if lower_b.ndim < lb.ndim: + lb = lb.transpose(0,1).reshape(lb.size(1), lb.size(0), -1) + lb = lb.expand(lb.size(0), lb.size(1), lower_b.size(0)//lb.size(1)) + lb = lb.reshape(lb.size(0), -1).t() + ub = ub.transpose(0,1).reshape(ub.size(1), ub.size(0), -1) + ub = ub.expand(ub.size(0), ub.size(1), upper_b.size(0)//ub.size(1)) + ub = ub.reshape(ub.size(0), -1).t() + elif lower_b.ndim > lb.ndim: + lower_b = lower_b.transpose(0,1).reshape(lower_b.size(1), -1).t() + upper_b = upper_b.transpose(0,1).reshape(upper_b.size(1), -1).t() + return lb, ub, lower_b, upper_b + + +def inplace_unfold(image, kernel_size, stride=1, padding=0, inserted_zeros=0, output_padding=0): + # Image has size (batch_size, channel, height, width). + assert image.ndim == 4 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) # (left, right, top, bottom). + if len(padding) == 2: # (height direction, width direction). + padding = (padding[1], padding[1], padding[0], padding[0]) + if isinstance(output_padding, int): + output_padding = (output_padding, output_padding, output_padding, output_padding) # (left, right, top, bottom). + if len(output_padding) == 2: # (height direction, width direction). + output_padding = (output_padding[1], output_padding[1], output_padding[0], output_padding[0]) + if isinstance(stride, int): + stride = (stride, stride) # (height direction, width direction). + assert len(kernel_size) == 2 and len(padding) == 4 and len(stride) == 2 + # Make sure the image is large enough for the kernel. + assert image.size(2) + padding[2] + padding[3] >= kernel_size[0] and image.size(3) + padding[0] + padding[1] >= kernel_size[1] + if inserted_zeros > 0: + # We first need to insert zeros in the image before unfolding. + image = insert_zeros(image, inserted_zeros) + # padding = (padding[0], padding[1] + 1, padding[2], padding[3] + 1) + # Compute the number of patches. + # Formulation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold + patches_h = int((image.size(2) + padding[2] + padding[3] - (kernel_size[0] - 1) - 1) / stride[0] + 1) + patches_w = int((image.size(3) + padding[0] + padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1) + # Pad image. + if sum(padding) != 0: + image = torch.nn.functional.pad(image, padding) + # Save its original stride. + image_stride = image.stride() + matrix_strided = torch.as_strided(image, [ + # Shape of the output matrix.
+ image.size(0), # Batch size. + patches_h, # indices for each patch. + patches_w, + image.size(1), # Channel. + kernel_size[0], # indices for each pixel on a patch. + kernel_size[1]], [ + # Stride of the output matrix. + image_stride[0], # Batch size dimension, keep using the old stride. + image_stride[2] * stride[0], # Move patch in the height dimension. + image_stride[3] * stride[1], # Move patch in the width dimension. + image_stride[1], # Move to the next channel. + image_stride[2], # Move to the next row. + image_stride[3]]) # Move a pixel (on the width direction). + # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width) + if sum(output_padding) > 0: + output_padding = tuple(p if p > 0 else None for p in output_padding) + matrix_strided = matrix_strided[:, output_padding[2]:-output_padding[3] if output_padding[3] is not None else None, + output_padding[0]:-output_padding[1] if output_padding[1] is not None else None, :, :, :] + return matrix_strided + + +def maybe_unfold_patches(d_tensor, last_A, alpha_lookup_idx=None): + """ + Utility function to handle patch mode bound propagation in activation functions. + In patches mode, we need to unfold lower and upper slopes (as input "d_tensor"). In matrix mode we simply return. + """ + if d_tensor is None or last_A is None or isinstance(last_A, Tensor): + return d_tensor + + # Shape for d_tensor: + # sparse: [spec, batch, in_c, in_h, in_w] + # non-sparse (partially shared): [out_c, batch, in_c, in_h, in_w] + # non-sparse (not shared): [out_c*out_h*out_w, batch, in_c, in_h, in_w] + # shared (independent of output spec): [1, batch, in_c, in_h, in_w] + # The in_h, in_w dimensions must be unfolded as patches. + origin_d_shape = d_tensor.shape + if d_tensor.ndim == 6: + # Merge the (out_h, out_w) dimensions. + d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:]) + d_shape = d_tensor.size() + # Reshape to 4-D tensor to unfold. + d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:]) + # unfold the slope matrix as patches. Patch shape is (spec * batch, out_h, out_w, in_c, H, W). + d_unfolded = inplace_unfold( + d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, + padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding) + # Reshape to the original shape of d, e.g., for non-sparse it is (out_c, batch, out_h, out_w, in_c, H, W). + d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:]) + if last_A.unstable_idx is not None: + # Here we have d for all output neurons, but we only need to select unstable ones. + if d_unfolded_r.size(0) == 1: + # Shared alpha; sparse alpha should not be used. + assert alpha_lookup_idx is None + if len(last_A.unstable_idx) == 3: + # Broadcast the spec shape, so only need to select the rest dimensions. + # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W). + d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5) + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + elif len(last_A.unstable_idx) == 4: + # [spec, batch, output_h, output_w, input_c, H, W] + # to [output_h, output_w, batch, in_c, H, W] + d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5) + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]] + else: + raise NotImplementedError() + # output shape: (unstable_size, batch, in_c, H, W).
+ else: + # The spec dimension may be sparse and contains unstable neurons for the spec layer only. + if alpha_lookup_idx is None: + # alpha is spec-dense. Possible because the number of unstable neurons may decrease. + if last_A.size(0) == d_unfolded_r.size(0): + # Non spec-sparse, partially shared alpha among output channel dimension. + # Shape after unfolding is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w). + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] + else: + # Non spec-sparse, non-shared alpha. + # Shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w). + # Reshaped to (out_c, out_h, out_w, batch, out_h, out_w, in_c, patch_h, patch_w). + d_unfolded_r = d_unfolded_r.view(last_A.shape[0], last_A.shape[2], last_A.shape[3], -1, *d_unfolded_r.shape[2:]) + # Select on all out_c, out_h, out_w dimensions. + d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], last_A.unstable_idx[1], + last_A.unstable_idx[2], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] + elif alpha_lookup_idx.ndim == 1: + # sparse alpha: [spec, batch, in_c, in_h, in_w] + # Partially shared alpha on the spec dimension - all output neurons on the same channel use the same alpha. + # If alpha_lookup_idx is not None, we need to convert the sparse indices using alpha_lookup_idx. + _unstable_idx = alpha_lookup_idx[last_A.unstable_idx[0]] + # The selection is only used on the channel dimension. + d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]] + elif alpha_lookup_idx is not None and alpha_lookup_idx.ndim == 3: + # sparse alpha: [spec, batch, in_c, in_h, in_w] + # We created alpha as full output shape; alpha not shared among channel dimension. + # Shape of alpha is (out_c*out_h*out_w, batch, in_c, in_h, in_w), note that the first 3 dimensions + # is merged into one to allow simpler selection. + _unstable_idx = alpha_lookup_idx[ + last_A.unstable_idx[0], + last_A.unstable_idx[1], + last_A.unstable_idx[2]] + # d_unfolded_r shape from (out_c, batch, out_h, out_w, in_c, in_h, in_w) + # to (out_c * out_h * out_w(sparse), batch, in_c, in_h, in_w) + # Note that the dimensions out_h, out_w come from unfolding, not specs in alpha, so they will be selected + # directly without translating using the lookup table. + d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]] + # after selection we return (unstable_size, batch_size, in_c, H, W) + return d_unfolded_r + else: + raise ValueError + else: + # A is not sparse. Alpha shouldn't be sparse as well. + assert alpha_lookup_idx is None + if last_A.patches.size(0) != d_unfolded_r.size(0) and d_unfolded_r.size(0) != 1: + # Non-shared alpha, shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w). + # Reshaped to (out_c, out_h*out_w, batch, out_h*out_w, in_c, patch_h, patch_w). + d_unfolded_r = d_unfolded_r.reshape(last_A.shape[0], last_A.shape[2] * last_A.shape[3], -1, + d_unfolded_r.shape[2] * d_unfolded_r.shape[3], *d_unfolded_r.shape[4:]) + # Select the "diagonal" elements in the out_h*out_w dimension. 
+ # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h*out_w) + d_unfolded_r = d_unfolded_r.diagonal(offset=0, dim1=1, dim2=3) + # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h, out_w) + d_unfolded_r = d_unfolded_r.view(*d_unfolded_r.shape[:-1], last_A.shape[2], last_A.shape[3]) + # New shape is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w) + d_unfolded_r = d_unfolded_r.permute(0, 1, 5, 6, 2, 3, 4) + + # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W). + # For regular patches, the shape after unfold is (out_c, batch, out_h, out_w, in_c, H, W). + if d_unfolded_r.ndim != last_A.patches.ndim: + # For the case where d is independent of the output neuron (e.g., the vanilla CROWN bound), d does not have + # the out_h, out_w dimensions and out_c = 1 (spec). We add singleton out_h, out_w dimensions here. + d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4) + return d_unfolded_r diff --git a/auto_LiRPA/perturbations.py b/auto_LiRPA/perturbations.py index c3fc386..e198388 100644 --- a/auto_LiRPA/perturbations.py +++ b/auto_LiRPA/perturbations.py @@ -1,18 +1,18 @@ -import os import json import math import numpy as np import torch -import torch.nn as nn -from .utils import logger, eyeC, LinearBound, Patches, BoundList, patches_to_matrix, inplace_unfold -import torch.nn.functional as F +from .utils import logger, eyeC +from .patches import Patches, patches_to_matrix +from .linear_bound import LinearBound + class Perturbation: r""" Base class for a perturbation specification. Please see examples - at `auto_LiRPA/perturbations.py`. + at `auto_LiRPA/perturbations.py`. - Examples: + Examples: * `PerturbationLpNorm`: Lp-norm (p>=1) perturbation. @@ -26,7 +26,7 @@ def __init__(self): def set_eps(self, eps): self.eps = eps - + def concretize(self, x, A, sign=-1, aux=None): r""" Concretize bounds according to the perturbation specification. @@ -54,7 +54,7 @@ def init(self, x, aux=None, forward=False): aux (object, optional): Auxilary information. - forward (bool): It indicates whether forward mode LiRPA is involved. + forward (bool): It indicates whether forward mode LiRPA is involved. Returns: bound (LinearBound): Initialized bounds. 
@@ -69,16 +69,16 @@ def init(self, x, aux=None, forward=False): """Perturbation constrained by the L_0 norm (assuming input data is in the range of 0-1).""" class PerturbationL0Norm(Perturbation): - def __init__(self, eps, x_L = None, x_U = None, ratio = 1.0): + def __init__(self, eps, x_L=None, x_U=None, ratio=1.0): self.eps = eps self.x_U = x_U self.x_L = x_L self.ratio = ratio - def concretize(self, x, A, sign = -1, aux = None): + def concretize(self, x, A, sign=-1, aux=None): if A is None: return None - + eps = math.ceil(self.eps) x = x.reshape(x.shape[0], -1, 1) center = A.matmul(x) @@ -89,7 +89,6 @@ def concretize(self, x, A, sign = -1, aux = None): neg_mask = A < 0 pos_mask = A >= 0 - if sign == 1: A_diff = torch.zeros_like(A) A_diff[pos_mask] = A[pos_mask] - original[pos_mask]# changes that one weight can contribute to the value @@ -117,198 +116,193 @@ def init(self, x, aux=None, forward=False): eye = torch.eye(dim).to(x.device).unsqueeze(0).repeat(batch_size, 1, 1) lw = eye.reshape(batch_size, dim, *x.shape[1:]) lb = torch.zeros_like(x).to(x.device) - uw, ub = lw.clone(), lb.clone() + uw, ub = lw.clone(), lb.clone() return LinearBound(lw, lb, uw, ub, x_L, x_U), x, None def __repr__(self): return 'PerturbationLpNorm(norm=0, eps={})'.format(self.eps) + """Perturbation constrained by the L_p norm.""" class PerturbationLpNorm(Perturbation): - def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None, relative=False): + def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None): self.eps = eps self.norm = norm self.dual_norm = 1 if (norm == np.inf) else (np.float64(1.0) / (1 - 1.0 / self.norm)) self.x_L = x_L self.x_U = x_U - self.relative = relative + self.sparse = False - """Given an variable x and its bound matrix A, compute worst case bound according to Lp norm.""" - def concretize(self, x, A, sign=-1, aux=None, extra_constr=None): - if A is None: - return None - # If A is an identity matrix, we will handle specially. - def concretize_matrix(A): - nonlocal x - if not isinstance(A, eyeC): - # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. - A = A.reshape(A.shape[0], A.shape[1], -1) + def get_input_bounds(self, x, A): + if self.sparse: + if self.x_L_sparse.shape[-1] == A.shape[-1]: + x_L, x_U = self.x_L_sparse, self.x_U_sparse + else: + # In backward mode, A is not sparse + x_L, x_U = self.x_L, self.x_U + else: + x_L = x - self.eps if self.x_L is None else self.x_L + x_U = x + self.eps if self.x_U is None else self.x_U + return x_L, x_U + + # If A is an identity matrix, we will handle specially. + def concretize_matrix(self, x, A, sign, extra_constr): + if not isinstance(A, eyeC): + # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. + A = A.reshape(A.shape[0], A.shape[1], -1) + + if extra_constr is not None: + # For each neuron, we have a beta, so beta size is (Batch, *neuron_size, n_beta) (in A, spec is *neuron_size). + # For intermediate layer neurons, A has *neuron_size specifications. + beta = extra_constr['beta'] + beta = beta.view(beta.size(0), -1, beta.size(-1)) + # coeffs are linear relationships between split neurons and x. They have size (batch, n_beta, *input_size), and unrelated to neuron_size. + beta_coeffs = extra_constr['coeffs'] + beta_coeffs = beta_coeffs.view(beta_coeffs.size(0), beta_coeffs.size(1), -1) + # biases are added for each batch each spec, size is (batch, n_beta), and unrelated to neuron_size. + beta_bias = extra_constr['bias'] + # Merge beta into extra A and bias. 
Extra A has size (batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. + extra_A = torch.einsum('ijk,ikl->ijl', beta, beta_coeffs) + # Merge beta into the bias term. Output has size (batch, spec). + extra_bias = torch.einsum('ijk,ik->ij', beta, beta_bias) + if self.norm == np.inf: + # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps. + x_L, x_U = self.get_input_bounds(x, A) + x_ub = x_U.reshape(x_U.shape[0], -1, 1) + x_lb = x_L.reshape(x_L.shape[0], -1, 1) + # Find the upper and lower bound similarly to IBP. + center = (x_ub + x_lb) / 2.0 + diff = (x_ub - x_lb) / 2.0 + if not isinstance(A, eyeC): if extra_constr is not None: - # For each neuron, we have a beta, so beta size is (Batch, *neuron_size, n_beta) (in A, spec is *neuron_size). - # For intermediate layer neurons, A has *neuron_size specifications. - beta = extra_constr['beta'] - beta = beta.view(beta.size(0), -1, beta.size(-1)) - # coeffs are linear relationships between split neurons and x. They have size (batch, n_beta, *input_size), and unreated to neuron_size. - beta_coeffs = extra_constr['coeffs'] - beta_coeffs = beta_coeffs.view(beta_coeffs.size(0), beta_coeffs.size(1), -1) - # biases are added for each batch each spec, size is (batch, n_beta), and unrelated to neuron_size. - beta_bias = extra_constr['bias'] - # Merge beta into extra A and bias. Extra A has size (batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. - extra_A = torch.einsum('ijk,ikl->ijl', beta, beta_coeffs) - # Merge beta into the bias term. Output has size (batch, spec). - extra_bias = torch.einsum('ijk,ik->ij', beta, beta_bias) - - if self.norm == np.inf: - # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps. - x_L = x - self.eps if self.x_L is None else self.x_L - x_U = x + self.eps if self.x_U is None else self.x_U - x_ub = x_U.reshape(x_U.shape[0], -1, 1) - x_lb = x_L.reshape(x_L.shape[0], -1, 1) - # Find the uppwer and lower bound similarly to IBP. - center = (x_ub + x_lb) / 2.0 - diff = (x_ub - x_lb) / 2.0 - if not isinstance(A, eyeC): - if extra_constr is not None: - # Extra linear and bias terms from constraints. - print( - f'A extra: {(sign * extra_A).abs().sum().item()}, b extra: {(sign * extra_bias).abs().sum().item()}') - A = A - sign * extra_A - bound = A.matmul(center) - sign * extra_bias.unsqueeze(-1) + sign * A.abs().matmul(diff) - else: - bound = A.matmul(center) + sign * A.abs().matmul(diff) + # Extra linear and bias terms from constraints. + print( + f'A extra: {(sign * extra_A).abs().sum().item()}, ' + f'b extra: {(sign * extra_bias).abs().sum().item()}') + A = A - sign * extra_A + bound = A.matmul(center) - sign * extra_bias.unsqueeze(-1) + sign * A.abs().matmul(diff) else: - assert extra_constr is None - # A is an identity matrix. No need to do this matmul. - bound = center + sign * diff + bound = A.matmul(center) + sign * A.abs().matmul(diff) else: assert extra_constr is None - x = x.reshape(x.shape[0], -1, 1) - if not isinstance(A, eyeC): - # Find the upper and lower bounds via dual norm. - deviation = A.norm(self.dual_norm, -1) * self.eps - bound = A.matmul(x) + sign * deviation.unsqueeze(-1) - else: - # A is an identity matrix. Its norm is all 1. - bound = x + sign * self.eps - bound = bound.squeeze(-1) - return bound + # A is an identity matrix. No need to do this matmul.
+ bound = center + sign * diff + else: + assert extra_constr is None + x = x.reshape(x.shape[0], -1, 1) + if not isinstance(A, eyeC): + # Find the upper and lower bounds via dual norm. + deviation = A.norm(self.dual_norm, -1) * self.eps + bound = A.matmul(x) + sign * deviation.unsqueeze(-1) + else: + # A is an identity matrix. Its norm is all 1. + bound = x + sign * self.eps + bound = bound.squeeze(-1) + return bound - def concretize_patches(A): - nonlocal x - if self.norm == np.inf: - # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps. - x_L = x - self.eps if self.x_L is None else self.x_L - x_U = x + self.eps if self.x_U is None else self.x_U - - # Here we should not reshape - # Find the uppwer and lower bound similarly to IBP. - center = (x_U + x_L) / 2.0 - diff = (x_U - x_L) / 2.0 - if not A.identity == 1: - # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W] or [unstable_size, batch_size, in_c, H, W]. Here out_c is the spec dimension. - patches = A.patches - - # unfold the input as [batch_size, out_h, out_w, in_c, H, W] - unfold_input = inplace_unfold(center, kernel_size=A.patches.shape[-2:], padding = A.padding, stride = A.stride) - if A.unstable_idx is not None: - # We need to add a out_c dimension and select from it. - unfold_input = unfold_input.unsqueeze(0).expand(A.output_shape[1], -1, -1, -1, -1, -1, -1) - # When A is sparse, the shape is [unstable_size, batch_size, in_c, H, W]. Here unfold_input will match this shape. - unfold_input = unfold_input[A.unstable_idx[0], :, A.unstable_idx[1], A.unstable_idx[2]] - # size of bound: [batch_size, unstable_size]. - bound = torch.einsum('sbchw,sbchw->bs', unfold_input, patches) - else: - # size of bound: [batch_size, out_c, out_h, out_w]. - bound = torch.einsum('bijchw,sbijchw->bsij', unfold_input, patches) - - # unfold the diff as [batch_size, out_h, out_w, in_c, H, W] - unfold_diff = inplace_unfold(diff, kernel_size=A.patches.shape[-2:], padding = A.padding, stride = A.stride) - if A.unstable_idx is not None: - # We need to add a out_c dimension and select from it. - unfold_diff = unfold_diff.unsqueeze(0).expand(A.output_shape[1], -1, -1, -1, -1, -1, -1) - # When A is sparse, the shape is [unstable_size, batch_size, in_c, H, W] - unfold_diff = unfold_diff[A.unstable_idx[0], :, A.unstable_idx[1], A.unstable_idx[2]] - # size of diff: [batch_size, unstable_size]. - bound_diff = torch.einsum('sbchw,sbchw->bs', unfold_diff, patches.abs()) - else: - # size of diff: [batch_size, out_c, out_h, out_w]. - bound_diff = torch.einsum('bijchw,sbijchw->bsij', unfold_diff, patches.abs()) - - if sign == 1: - bound += bound_diff - elif sign == -1: - bound -= bound_diff - else: - raise ValueError("Unsupported Sign") - - # The extra bias term from beta term. - if extra_constr is not None: - bound += extra_constr - else: - assert extra_constr is None - # A is an identity matrix. No need to do this matmul. - bound = center + sign * diff - return bound - else: # Lp norm - # x_L = x - self.eps if self.x_L is None else self.x_L - # x_U = x + self.eps if self.x_U is None else self.x_U - - input_shape = x.shape - if not A.identity: - # Find the upper and lower bounds via dual norm. - # matrix has shape (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) or (batch_size, unstable_size, input_c, input_h, input_w) - matrix = patches_to_matrix(A.patches, input_shape, A.stride, A.padding, A.output_shape, A.unstable_idx) - # Note that we should avoid reshape the matrix. 
Due to padding, matrix cannot be reshaped without copying. - deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps - # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size). - bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation - if A.unstable_idx is None: - # Reshape to (batch, out_c, out_h, out_w). - bound = bound.view(matrix.size(0), A.patches.size(0), A.patches.size(2), A.patches.size(3)) + def concretize_patches(self, x, A, sign, extra_constr): + if self.norm == np.inf: + x_L, x_U = self.get_input_bounds(x, A) + + # We should not reshape x_L and x_U here; they keep the input's spatial shape. + # Find the upper and lower bound similarly to IBP. + center = (x_U + x_L) / 2.0 + diff = (x_U - x_L) / 2.0 + + if not A.identity == 1: + bound = A.matmul(center) + bound_diff = A.matmul(diff, patch_abs=True) + + if sign == 1: + bound += bound_diff + elif sign == -1: + bound -= bound_diff + else: - # A is an identity matrix. Its norm is all 1. - bound = x + sign * self.eps - return bound + raise ValueError("Unsupported Sign") + + # The extra bias term from the beta (split constraint) terms. + if extra_constr is not None: + bound += extra_constr + else: + assert extra_constr is None + # A is an identity matrix. No need to do this matmul. + bound = center + sign * diff + return bound + else: # Lp norm + input_shape = x.shape + if not A.identity: + # Find the upper and lower bounds via dual norm. + # matrix has shape (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) or (batch_size, unstable_size, input_c, input_h, input_w) + matrix = patches_to_matrix(A.patches, input_shape, A.stride, A.padding, A.output_shape, A.unstable_idx) + # Note that we should avoid reshaping the matrix. Due to padding, matrix cannot be reshaped without copying. + deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps + # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size). + bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation + if A.unstable_idx is None: + # Reshape to (batch, out_c, out_h, out_w). + bound = bound.view(matrix.size(0), A.patches.size(0), A.patches.size(2), A.patches.size(3)) + else: + # A is an identity matrix. Its norm is all 1.
+ bound = x + sign * self.eps + return bound + """Given a variable x and its bound matrix A, compute the worst-case bound according to the Lp norm.""" + def concretize(self, x, A, sign=-1, aux=None, extra_constr=None): + if A is None: + return None if isinstance(A, eyeC) or isinstance(A, torch.Tensor): - return concretize_matrix(A) + return self.concretize_matrix(x, A, sign, extra_constr) elif isinstance(A, Patches): - return concretize_patches(A) + return self.concretize_patches(x, A, sign, extra_constr) - elif isinstance(A, BoundList): - for b in A.bound_list: - if isinstance(b, eyeC) or isinstance(b, torch.Tensor): - pass else: raise NotImplementedError() + """Sparse Linf perturbation where only a few dimensions are actually perturbed.""" + def init_sparse_linf(self, x, x_L, x_U): + self.sparse = True + batch_size = x_L.shape[0] + perturbed = (x_U > x_L).int() + logger.debug(f'Perturbed: {perturbed.sum()}') + lb = ub = x_L * (1 - perturbed) # x_L=x_U holds when perturbed=0 + perturbed = perturbed.view(batch_size, -1) + index = torch.cumsum(perturbed, dim=-1) + dim = max(perturbed.view(batch_size, -1).sum(dim=-1).max(), 1) + self.x_L_sparse = torch.zeros(batch_size, dim + 1).to(x_L) + self.x_L_sparse.scatter_(dim=-1, index=index, src=(x_L - lb).view(batch_size, -1), reduce='add') + self.x_U_sparse = torch.zeros(batch_size, dim + 1).to(x_U) + self.x_U_sparse.scatter_(dim=-1, index=index, src=(x_U - ub).view(batch_size, -1), reduce='add') + self.x_L_sparse, self.x_U_sparse = self.x_L_sparse[:, 1:], self.x_U_sparse[:, 1:] + lw = torch.zeros(batch_size, dim + 1, perturbed.shape[-1], device=x.device) + perturbed = perturbed.to(torch.get_default_dtype()) + lw.scatter_(dim=1, index=index.unsqueeze(1), src=perturbed.unsqueeze(1)) + lw = uw = lw[:, 1:, :].view(batch_size, dim, *x.shape[1:]) + print(f'Using Linf sparse perturbation. Perturbed dimensions: {dim}.') + print(f'Avg perturbation: {(self.x_U_sparse - self.x_L_sparse).mean()}') + return LinearBound( + lw, lb, uw, ub, x_L, x_U), x, None + def init(self, x, aux=None, forward=False): + self.sparse = False if self.norm == np.inf: x_L = x - self.eps if self.x_L is None else self.x_L x_U = x + self.eps if self.x_U is None else self.x_U else: # For other norms, we pass in the BoundedTensor objects directly.
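The two concretization paths above reduce to simple closed forms: over the L∞ box `[center - diff, center + diff]` the extremum of a linear function is `A @ center ± |A| @ diff`, and over a general Lp ball of radius `eps` it is `A @ x ± eps * ||A||_q` row-wise, where `q` is the dual norm with `1/p + 1/q = 1`. A small self-contained check of both identities (illustrative only, not part of this patch):

```python
# Verify the interval and dual-norm concretization formulas on one row `a`
# of a linear relaxation A, around a nominal input x0.
import torch

torch.manual_seed(0)
a = torch.randn(5)    # one row of the linear relaxation A
x0 = torch.randn(5)   # nominal input
eps = 0.1

# Linf ball: the closed form a @ x0 + eps * ||a||_1 equals the value attained
# at the worst-case corner x0 + eps * sign(a).
ub_linf = a @ x0 + eps * a.abs().sum()
assert torch.allclose(ub_linf, a @ (x0 + eps * a.sign()))

# L2 ball (p = 2, so dual norm q = 2): the worst case is attained at
# x0 + eps * a / ||a||_2.
ub_l2 = a @ x0 + eps * a.norm(2)
assert torch.allclose(ub_l2, a @ (x0 + eps * a / a.norm(2)))
print('dual-norm concretization checks passed')
```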
- x_L = x - x_U = x - if self.relative: - nominal = x - lower_offset = torch.max(x_L - x - 1e-8, torch.ones_like(x_L) * (-self.eps)) - upper_offset = torch.min(x_U - x + 1e-8, torch.ones_like(x_U) * (self.eps)) - else: - nominal = lower_offset = upper_offset = None + x_L = x_U = x if not forward: return LinearBound( - None, None, None, None, x_L, x_U, - nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset), x, None + None, None, None, None, x_L, x_U), x, None + if self.norm == np.inf and x_L.numel() > 1 and (x_L == x_U).sum() > 0.5 * x_L.numel(): + return self.init_sparse_linf(x, x_L, x_U) + batch_size = x.shape[0] dim = x.reshape(batch_size, -1).shape[-1] + lb = ub = torch.zeros_like(x) eye = torch.eye(dim).to(x).expand(batch_size, dim, dim) - lw = eye.reshape(batch_size, dim, *x.shape[1:]) - lb = torch.zeros_like(x).to(x.device) - uw, ub = lw.clone(), lb.clone() + lw = uw = eye.reshape(batch_size, dim, *x.shape[1:]) return LinearBound( - lw, lb, uw, ub, x_L, x_U, - nominal=nominal, lower_offset=lower_offset, upper_offset=upper_offset), x, None + lw, lb, uw, ub, x_L, x_U), x, None def __repr__(self): if self.norm == np.inf: @@ -354,7 +348,7 @@ def concretize(self, x, A, sign, aux): num_pos = int(np.max(np.sum(can_be_replaced, axis=-1))) update_A = A.shape[-1] > num_pos * dim_word if update_A: - bias = torch.bmm(A, (x * (1 - mask_rep).unsqueeze(-1)).reshape(batch_size, -1, 1)).squeeze(-1) + bias = torch.bmm(A, (x * (1 - mask_rep).unsqueeze(-1)).reshape(batch_size, -1, 1)).squeeze(-1) else: bias = 0. A = A.reshape(batch_size, dim_out, -1, dim_word) @@ -386,7 +380,7 @@ def concretize(self, x, A, sign, aux): mask = torch.cat(mask_new).reshape(batch_size, num_pos, max_num_cand) length = num_pos - A = A.reshape(batch_size, A.shape[1], length, -1).transpose(1, 2) + A = A.reshape(batch_size, A.shape[1], length, -1).transpose(1, 2) x = x.reshape(batch_size, length, -1, 1) if sign == 1: @@ -396,8 +390,8 @@ def concretize(self, x, A, sign, aux): init_tensor = torch.ones(batch_size, dim_out).to(x.device) * init dp = [[init_tensor] * (self.budget + 1) for i in range(0, length + 1)] - dp[0][0] = torch.zeros(batch_size, dim_out).to(x.device) - + dp[0][0] = torch.zeros(batch_size, dim_out).to(x.device) + A = A.reshape(batch_size * length, A.shape[2], A.shape[3]) Ax = torch.bmm( A, @@ -418,10 +412,10 @@ def concretize(self, x, A, sign, aux): dp[i][0] = dp[i - 1][0] + Ax[:, i - 1] for j in range(1, self.budget + 1): dp[i][j] = cmp( - dp[i - 1][j] + Ax[:, i - 1], + dp[i - 1][j] + Ax[:, i - 1], dp[i - 1][j - 1] + Ax_rep_bound[:, i - 1] ) - dp = torch.cat(dp[length], dim=0).reshape(self.budget + 1, batch_size, dim_out) + dp = torch.cat(dp[length], dim=0).reshape(self.budget + 1, batch_size, dim_out) return cmp(dp, dim=0).values + bias @@ -432,7 +426,7 @@ def init(self, x, aux=None, forward=False): batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2] max_pos = 1 - can_be_replaced = np.zeros((batch_size, length), dtype=np.bool) + can_be_replaced = np.zeros((batch_size, length), dtype=bool) self._build_substitution(batch) @@ -457,8 +451,8 @@ def init(self, x, aux=None, forward=False): if forward: eye = torch.eye(dim_word).to(x.device) lw = torch.zeros(batch_size, dim, length, dim_word).to(x.device) - lb = torch.zeros_like(x).to(x.device) - x_new = [] + lb = torch.zeros_like(x).to(x.device) + x_new = [] word_embeddings = self.model.word_embeddings.weight vocab = self.model.vocab x_rep = [[[] for i in range(length)] for t in range(batch_size)] @@ -467,8 +461,8 @@ def init(self, x,
aux=None, forward=False): candidates = batch[t]['candidates'] # for transformers if tokens[t][0] == '[CLS]': - candidates = [[]] + candidates + [[]] - cnt = 0 + candidates = [[]] + candidates + [[]] + cnt = 0 for i in range(length): if can_be_replaced[t][i]: word_embed = word_embeddings[vocab[tokens[t][i]]] @@ -485,13 +479,13 @@ def init(self, x, aux=None, forward=False): cnt += 1 else: if forward: - lb[t, i, :] = x[t, i, :] + lb[t, i, :] = x[t, i, :] if forward: uw, ub = lw, lb else: lw = lb = uw = ub = None zeros = torch.zeros(dim_word, device=x.device) - + x_rep_, mask = [], [] for t in range(batch_size): for i in range(length): @@ -501,7 +495,7 @@ def init(self, x, aux=None, forward=False): mask = torch.tensor(mask, dtype=torch.float32, device=x.device)\ .reshape(batch_size, length, max_num_cand) x_rep_ = x_rep_ * self.eps + x.unsqueeze(2) * (1 - self.eps) - + inf = 1e20 lower = torch.min(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * inf, dim=2).values upper = torch.max(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * (-inf), dim=2).values diff --git a/auto_LiRPA/solver_module.py b/auto_LiRPA/solver_module.py new file mode 100644 index 0000000..9720ce5 --- /dev/null +++ b/auto_LiRPA/solver_module.py @@ -0,0 +1,132 @@ +import multiprocessing +import multiprocessing.pool +import sys +import os + +import torch +from .bound_ops import * + + +def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, + final_node_name=None, model_type="mip", solver_pkg="gurobi"): + r"""Build LP/MIP solvers for a general computational graph. + + Args: + x: inputs, a list of BoundedTensor. If set to None, we reuse existing bounds that + were previously computed in compute_bounds(). + C (Tensor): The specification matrix that maps the output of the model through an + additional linear layer. This is usually used for mapping the logits output of the + model to classification margins. + intermediate_layer_bounds: if specified, will replace existing intermediate layer bounds. + Otherwise we reuse existing intermediate bounds. + + final_node_name (String): the name of the target layer to optimize. + + solver_pkg (String): the solver backend; the default is gurobi, and scipy is also supported. + + Returns: + output vars (list): a list of final nodes to optimize + """ + # self.root_name: list of root node name + # self.final_name: list of output node name + # self.final_node: output module + # .input: a list of input modules of this layer module + # .solver_vars: a list of gurobi vars of every layer module + # list with conv shape if conv layers, otherwise flattened + # if last layer we need to be careful with: + # C: specification matrix + # .is_input_perturbed(1) + + if x is not None: + assert intermediate_layer_bounds is not None + # Set the model to use new intermediate layer bounds, ignore the original ones.
+ self._set_input(x, intermediate_layer_bounds=intermediate_layer_bounds) + + root = [self[name] for name in self.root_name] + + # create interval ranges for input and other weight parameters + for i in range(len(root)): + value = root[i].forward() + # if isinstance(root[i], BoundInput) and not isinstance(root[i], BoundParams): + if type(root[i]) is BoundInput: + # create input vars for gurobi self.model + inp_gurobi_vars = self._build_solver_input(root[i]) + else: + # regular weights + root[i].solver_vars = value + + final = self.final_node() if final_node_name is None else self[final_node_name] + + # backward propagate every layer including last layer + self._build_solver_general(node=final, C=C, model_type=model_type, solver_pkg=solver_pkg) + + # a list of output solver vars + return final.solver_vars + + +def _build_solver_general(self, node, C=None, model_type="mip", solver_pkg="gurobi"): + if not hasattr(node, 'solver_vars'): + for n in node.inputs: + self._build_solver_general(n, C=C, model_type=model_type, solver_pkg=solver_pkg) + inp = [n_pre.solver_vars for n_pre in node.inputs] + # print(node, node.inputs) + if C is not None and isinstance(node, BoundLinear) and\ + not node.is_input_perturbed(1) and self.final_name == node.name: + # when node is the last layer + # merge the last BoundLinear node with the specification, + # available when weights of this layer are not perturbed + solver_vars = node.build_solver(*inp, model=self.model, C=C, + model_type=model_type, solver_pkg=solver_pkg) + else: + solver_vars = node.build_solver(*inp, model=self.model, C=None, + model_type=model_type, solver_pkg=solver_pkg) + # just return output node gurobi vars + return solver_vars + + +def _build_solver_input(self, node): + ## Do the input layer, which is a special case + assert isinstance(node, BoundInput) + assert node.perturbation is not None + assert node.perturbation.norm == float("inf") + inp_gurobi_vars = [] + # zero var will be shared within the solver model + zero_var = self.model.addVar(lb=0, ub=0, obj=0, vtype=grb.GRB.CONTINUOUS, name='zero') + x_L = node.value - node.perturbation.eps if node.perturbation.x_L is None else node.perturbation.x_L + x_U = node.value + node.perturbation.eps if node.perturbation.x_U is None else node.perturbation.x_U + x_L = x_L.squeeze(0) + x_U = x_U.squeeze(0) + # x_L, x_U = node.lower.squeeze(0), node.upper.squeeze(0) + + if x_L.ndim == 1: + # This is a linear input. 
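The per-dimension variable creation that follows uses the standard Gurobi Python API. As a standalone sketch (assuming `gurobipy` is installed; in the library itself `grb` is expected to come from the wildcard `bound_ops` import above), the same pattern looks like this:

```python
# Standalone illustration of the input-variable creation in _build_solver_input:
# one continuous Gurobi variable per flattened input dimension, box-constrained
# by the concrete input bounds [x_L, x_U].
import gurobipy as grb
import torch

model = grb.Model()
x_L = torch.tensor([0.0, -1.0, 0.5])
x_U = torch.tensor([1.0, 1.0, 0.5])  # a dimension with x_L == x_U is effectively fixed
inp_vars = [
    model.addVar(lb=lb.item(), ub=ub.item(), obj=0,
                 vtype=grb.GRB.CONTINUOUS, name=f'inp_{dim}')
    for dim, (lb, ub) in enumerate(zip(x_L, x_U))
]
model.update()
print([v.VarName for v in inp_vars])  # ['inp_0', 'inp_1', 'inp_2']
```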
+ for dim, (lb, ub) in enumerate(zip(x_L, x_U)): + v = self.model.addVar(lb=lb, ub=ub, obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'inp_{dim}') + inp_gurobi_vars.append(v) + else: + assert x_L.ndim == 3, f"x_L ndim {x_L.ndim}" + dim = 0 + for chan in range(x_L.shape[0]): + chan_vars = [] + for row in range(x_L.shape[1]): + row_vars = [] + for col in range(x_L.shape[2]): + lb = x_L[chan, row, col] + ub = x_U[chan, row, col] + v = self.model.addVar(lb=lb, ub=ub, obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'inp_{dim}') + # name=f'inp_[{chan},{row},{col}]') + row_vars.append(v) + dim += 1 + chan_vars.append(row_vars) + inp_gurobi_vars.append(chan_vars) + + node.solver_vars = inp_gurobi_vars + # save the gurobi input variables so that we can later extract primal values in input space easily + self.input_vars = inp_gurobi_vars + self.model.update() + return inp_gurobi_vars + diff --git a/auto_LiRPA/utils.py b/auto_LiRPA/utils.py index a6c4280..e50ea0c 100644 --- a/auto_LiRPA/utils.py +++ b/auto_LiRPA/utils.py @@ -3,15 +3,18 @@ import time import torch import torch.nn as nn +import torch.nn.functional as F import os import sys import appdirs -from collections import defaultdict, Sequence, namedtuple +from collections import defaultdict, namedtuple +from collections.abc import Sequence from functools import reduce import operator import math -import torch.nn.functional as F import warnings +from typing import Tuple +from .patches import Patches, insert_zeros logging.basicConfig( format='%(levelname)-8s %(asctime)-12s %(message)s', @@ -19,7 +22,7 @@ stream=sys.stdout ) logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) +logger.setLevel(logging.DEBUG if os.environ.get('AUTOLIRPA_DEBUG', 0) else logging.INFO) warnings.simplefilter("once") @@ -35,6 +38,24 @@ reduction_max = lambda x: x.max(1, keepdim=True).values reduction_min = lambda x: x.min(1, keepdim=True).values +MIN_HALF_FP = 5e-8 # Slightly below 2**-24, the smallest positive value that float16 can represent. + + +def reduction_str2func(reduction_func): + if type(reduction_func) == str: + if reduction_func == 'min': + return reduction_min + elif reduction_func == 'max': + return reduction_max + elif reduction_func == 'sum': + return reduction_sum + elif reduction_func == 'mean': + return reduction_mean + else: + raise NotImplementedError(f'Unknown reduction_func {reduction_func}') + else: + return reduction_func + def stop_criterion_sum(threshold=0): return lambda x: (x.sum(1, keepdim=True) > threshold) @@ -47,34 +68,22 @@ def stop_criterion_min(threshold=0): return lambda x: (x.min(1, keepdim=True).values > threshold) def stop_criterion_max(threshold=0): return lambda x: (x.max(1, keepdim=True).values > threshold) -# Create a namedtuple with defaults -def namedtuple_with_defaults(name, attr, defaults): - assert sys.version_info.major == 3 - if sys.version_info.major >= 7: - return namedtuple(name, attr, defaults=defaults) - else: - # The defaults argument is not available in Python < 3.7 - t = namedtuple(name, attr) - t.__new__.__defaults__ = defaults - return t - -# A special class which denotes a convoluntional operator as a group of patches -# the shape of Patches.patches is [batch_size, num_of_patches, out_channel, in_channel, M, M] -# M is the size of a single patch -# Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N) -# num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2 -# Here we only consider kernels with the same H and W -Patches = namedtuple_with_defaults('Patches', ('patches', 'stride', 'padding',
'shape', 'identity', 'unstable_idx', 'output_shape'), defaults=(None, 1, 0, None, 0, None, None)) -BoundList = namedtuple_with_defaults('BoundList', ('bound_list'), defaults=([],)) -# Linear bounds with coefficients. Used for forward bound propagation. -LinearBound = namedtuple_with_defaults('LinearBound', ('lw', 'lb', 'uw', 'ub', 'lower', 'upper', 'from_input', 'nominal', 'lower_offset', 'upper_offset'), defaults=(None,) * 10) - -# for debugging -if False: - file_handler = logging.FileHandler('debug.log') - file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s')) - logger.addHandler(file_handler) - logger.setLevel(logging.DEBUG) +def stop_criterion_batch(threshold=0): + # May broadcast unexpectedly; pay attention to the shape of threshold. + # x shape: batch, number_bounds; threshold shape: batch, number_bounds + # print('threshold', threshold.shape) + return lambda x: (x > threshold) + +def stop_criterion_batch_any(threshold=0): + # May broadcast unexpectedly; pay attention to the shape of threshold. + # x shape: batch, number_bounds; threshold shape: batch, number_bounds + # print('threshold', threshold.shape) + return lambda x: (x > threshold).any(dim=1) + +def stop_criterion_batch_topk(threshold=0, k=1314): + # x shape: batch, number_bounds; threshold shape: batch, number_bounds + # print('threshold', threshold.shape) + return lambda x: (torch.kthvalue(x, k, dim=-1, keepdim=True).values > threshold).any(dim=1) user_data_dir = appdirs.user_data_dir('auto_LiRPA') if not os.path.exists(user_data_dir): @@ -171,7 +180,7 @@ def scale_gradients(optimizer, gradient_accumulation_steps, grad_clip=None): if param.grad is not None: param.grad.data /= gradient_accumulation_steps if grad_clip is not None: - torch.nn.utils.clip_grad_norm_(parameters, grad_clip) + return torch.nn.utils.clip_grad_norm_(parameters, grad_clip) def recursive_map (seq, func): for item in seq: @@ -224,55 +233,6 @@ def batched_index_select(input, dim, index): return torch.gather(input, dim, index) -"""Converting a Patches piece into a full dense matrix.""" -def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, unstable_idx=None): - if type(padding) == int: - padding = (padding, padding, padding, padding) - if output_shape is None: - assert pieces.ndim == 7 - # Non-sparse pieces, with shape (out_c, batch, out_h, out_w, c, h, w). - output_channel, batch_size, output_x, output_y = pieces.shape[:4] - else: - batch_size, output_channel, output_x, output_y = output_shape - input_channel, kernel_x, kernel_y = pieces.shape[-3:] - input_x, input_y = input_shape[-2:] - - if unstable_idx is None: - # Fix all patches in a full A matrix. - A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device) - # Save its orignal stride. - orig_stride = A_matrix.stride() - # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution. - # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.
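The "main trick" described in the removed comment above (the implementation now lives in `auto_LiRPA/patches.py`, per the new imports in this patch) is `torch.as_strided`: it creates a view that contains every sliding window, so filling the view fills the underlying matrix without copying any data. A tiny 1-D illustration (not part of this patch):

```python
# Build a view of a 1-D tensor containing all length-3 sliding windows.
import torch

x = torch.arange(6.)  # tensor([0., 1., 2., 3., 4., 5.])
kernel, stride = 3, 1
n_windows = (x.numel() - kernel) // stride + 1
windows = torch.as_strided(x, size=(n_windows, kernel),
                           stride=(stride * x.stride(0), x.stride(0)))
print(windows)
# tensor([[0., 1., 2.],
#         [1., 2., 3.],
#         [2., 3., 4.],
#         [3., 4., 5.]])
# Writing through `windows` writes into `x`; this is how patches_to_matrix
# scatters conv kernels into the large dense A matrix in place.
```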
- matrix_strided = torch.as_strided(A_matrix, [batch_size, output_channel, output_x, output_y, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], orig_stride[2], orig_stride[3], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[4], input_y + padding[0] + padding[1], 1]) - # Now we need to fill the conv kernel parameters into the last three dimensions of matrix_strided. - first_indices = torch.arange(output_x * output_y, device=pieces.device) - second_indices = torch.div(first_indices, output_y, rounding_mode="trunc") - third_indices = torch.fmod(first_indices, output_y) - # pieces have shape (out_c, batch, out_h, out_w, c, h, w). - pieces = pieces.transpose(0, 1) # pieces has the out_c dimension at the front, need to move it to the second. - matrix_strided[:,:,second_indices,third_indices,second_indices,third_indices,:,:,:] = pieces.reshape(*pieces.shape[:2], -1, *pieces.shape[4:]) - A_matrix = A_matrix.view(batch_size, output_channel * output_x * output_y, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1]) - else: - # Fill only a selection of patches. - # Create only a partial A matrix. - unstable_size = unstable_idx[0].numel() - A_matrix = torch.zeros(batch_size, unstable_size, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device) - # Save its orignal stride. - orig_stride = A_matrix.stride() - # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution. - # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient. - matrix_strided = torch.as_strided(A_matrix, [batch_size, unstable_size, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[2], input_y + padding[0] + padding[1], 1]) - # pieces have shape (unstable_size, batch, c, h, w). - first_indices = torch.arange(unstable_size, device=pieces.device) - matrix_strided[:,first_indices,unstable_idx[1],unstable_idx[2],:,:,:] = pieces.transpose(0,1) - A_matrix = A_matrix.view(batch_size, unstable_size, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1]) - - A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]] - - return A_matrix - - def check_padding(x, padding): if isinstance(padding, int): return x, (padding, padding) @@ -283,47 +243,6 @@ def check_padding(x, padding): return F.pad(x, padding), (0, 0) -def inplace_unfold(image, kernel_size, stride=1, padding=0): - # Image has size (batch_size, channel, height, width). - assert image.ndim == 4 - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if isinstance(padding, int): - padding = (padding, padding, padding, padding) # (left, right, top, bottom). - if len(padding) == 2: # (height direction, width direction). - padding = (padding[1], padding[1], padding[0], padding[0]) - if isinstance(stride, int): - stride = (stride, stride) # (height direction, width direction). - assert len(kernel_size) == 2 and len(padding) == 4 and len(stride) == 2 - # Make sure the image is large enough for the kernel. - assert image.size(2) + padding[2] + padding[3] >= kernel_size[0] and image.size(3) + padding[0] + padding[1] >= kernel_size[1] - # Compute the number of patches. 
- # Formulation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold - patches_h = int((image.size(2) + padding[2] + padding[3] - (kernel_size[0] - 1) - 1) / stride[0] + 1) - patches_w = int((image.size(3) + padding[0] + padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1) - # Pad image. - if sum(padding) != 0: - image = torch.nn.functional.pad(image, padding) - # Save its orignal stride. - image_stride = image.stride() - matrix_strided = torch.as_strided(image, [ - # Shape of the output matrix. - image.size(0), # Batch size. - patches_h, # indices for each patch. - patches_w, - image.size(1), # Channel. - kernel_size[0], # indices for each pixel on a patch. - kernel_size[1]], [ - # Stride of the output matrix. - image_stride[0], # Batch size dimension, keep using the old stride. - image_stride[2] * stride[0], # Move patch in the height dimension. - image_stride[3] * stride[1], # Move patch in the width dimension. - image_stride[1], # Move to the next channel. - image_stride[2], # Move to the next row. - image_stride[3]]) # Move a pixel (on the width direction). - # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width) - return matrix_strided - def get_spec_matrix(X, y, num_classes): with torch.no_grad(): c = (torch.eye(num_classes).type_as(X)[y].unsqueeze(1) @@ -331,3 +250,40 @@ def get_spec_matrix(X, y, num_classes): I = (~(y.unsqueeze(1) == torch.arange(num_classes).type_as(y).unsqueeze(0))) c = (c[I].view(X.size(0), num_classes - 1, num_classes)) return c + +def unravel_index( + indices: torch.LongTensor, + shape: Tuple[int, ...], +) -> torch.LongTensor: + r"""Converts flat indices into unraveled coordinates in a target shape. + + Args: + indices: A tensor of (flat) indices, (*, N). + shape: The targeted shape, (D,). + + Returns: + The unraveled coordinates, a list with tensors in shape (N, D). + + Code borrowed from: + https://github.com/pytorch/pytorch/issues/35674 + """ + + coord = [] + + for dim in reversed(shape): + coord.append(indices % dim) + indices = torch.div(indices, dim, rounding_mode='trunc') + + return list(reversed(coord)) + +def get_A_shape(A): + if A is None: + return 'None' + if isinstance(A, Patches): + if A.patches is not None: + return A.patches.shape + else: + return A.shape + if isinstance(A, torch.Tensor): + return A.shape + return 'Unknown' diff --git a/doc/README.md b/doc/README.md index 48a25fd..c4b102b 100644 --- a/doc/README.md +++ b/doc/README.md @@ -1,7 +1,7 @@ - # Documentation This directory contains source files for building our documentation. +Please view the compiled documentation on our [documentation page](https://auto-lirpa.readthedocs.io/en/latest/?badge=latest), as some links may not work here on GitHub. ## Dependencies diff --git a/doc/api.rst b/doc/api.rst index f70bf4b..5375cb5 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -18,8 +18,6 @@ API Usage .. autofunction:: auto_LiRPA.perturbations.Perturbation.concretize .. autofunction:: auto_LiRPA.perturbations.Perturbation.init -.. 
autofunction:: auto_LiRPA.bound_op_map.register_custom_op - Indices and tables ------------------- diff --git a/doc/process.py b/doc/process.py index 93ffc48..6ad1116 100644 --- a/doc/process.py +++ b/doc/process.py @@ -5,7 +5,7 @@ from pygit2 import Repository repo = 'https://github.com/KaidiXu/auto_LiRPA' -branch = os.environ.get('BRANCH', None) or Repository('.').head.shorthand +branch = Repository('.').head.shorthand repo_file_path = os.path.join(repo, 'tree', branch) """ Parse README.md into sections which can be reused """ @@ -40,7 +40,6 @@ source = file.read() source_new = '' ptr = 0 - # res = re.findall('\[.*\]\(.*\)', source) for m in re.finditer('(\[.*\])(\(.*\))', source): assert m.start() >= ptr source_new += source[ptr:m.start()] diff --git a/doc/src/bound_opts.md b/doc/src/bound_opts.md index 4fc9f9c..3d21b72 100644 --- a/doc/src/bound_opts.md +++ b/doc/src/bound_opts.md @@ -1,34 +1,41 @@ Bound Options ==================== -Bound options can be set by passing a dictionary to the `bound_opts` argument for `BoundedModule`. +Bound options can be set by passing a dictionary to the `bound_opts` argument for `BoundedModule`. This page lists available bound options. ## Arguments for Optimizing Bounds (`optimize_bound_args`) Arguments for optimizing bounds with the `CROWN-Optimized` method can be provided as a dictionary. Available arguments include: -* `ob_alpha` (bool, default `True`): Enable α-CROWN (optimized CROWN/LiRPA). +* `enable_alpha_crown` (bool, default `True`): Enable α-CROWN (optimized CROWN/LiRPA). -* `ob_beta` (bool, default `False`): Enable β-CROWN. +* `enable_beta_crown` (bool, default `False`): Enable β-CROWN. -* `ob_optimizer` (str, default `adam`): Optimzier. Set it to `adam-autolr` to use `AdamElementLR`, or `sgd` to use SGD. +* `optimizer` (str, default `adam`): Optimizer. Set it to `adam-autolr` to use `AdamElementLR`, or `sgd` to use SGD. -* `ob_verbose` (int, default 0): If greater than 1, print verbosely. +* `lr_alpha` (float, default 0.5), `lr_beta` (default 0.05): Learning rates for α and β parameters in α-CROWN and β-CROWN. -* `ob_lr` (default 0.5), `ob_lr_beta` (default 0.05): Learning rates for α and β parameters in α-CROWN and β-CROWN. +* `lr_decay` (float, default 0.98): Learning rate decay factor for the `ExponentialLR` scheduler. -* `ob_lr_decay` (default 0.98): Learning rate decay factor for the `ExponentialLR` scheduler. +* `iteration` (int): Number of optimization iterations. -* `ob_iteration` (int): Number of optimization iterations. +* `loss_reduction_func` (function): Function for loss reduction over the specification dimension. By default, use `auto_LiRPA.utils.reduction_sum` which sums the bound over all batch elements and specifications. -* `ob_loss_reduction_func` (function): Function for loss reduction over the specification dimension. By default, use `auto_LiRPA.utils.reduction_sum` which sumes the bound over all batch elements and specifications. +* `stop_criterion_func` (function): Function for the criterion of stopping optimization early; it returns a tensor of `torch.bool` with `batch_size` elements. By default, it is a lambda function that always returns `False`. Several pre-defined options are `auto_LiRPA.utils.stop_criterion_min`, `auto_LiRPA.utils.stop_criterion_mean`, `auto_LiRPA.utils.stop_criterion_max` and `auto_LiRPA.utils.stop_criterion_sum`.
For example, `auto_LiRPA.utils.stop_criterion_min` checks the minimum bound over all specifications of a batch element and returns `True` for that element when the minimum bound is greater than a specified threshold. -* `ob_stop_criterion_func` (function): Function for the criterion of stopping optimization early; it returns a tensor of `torch.bool` with `batch_size` elements. By default, it is a lambda function that always returns `False` . Several pre-defined options are `auto_LiRPA.utils.stop_criterion_min`, `auto_LiRPA.utils.stop_criterion_mean`, `auto_LiRPA.utils.stop_criterion_max` and `auto_LiRPA.utils.stop_criterion_sum`. For example, `auto_LiRPA.utils.stop_criterion_min` checks the minimum bound over all specifications of a batch element and returns `True` for that element when the minimum bound is greater than a specified threshold. +* `keep_best` (bool, default `True`): If `True`, save α, β and bounds at the best iteration. Otherwise the last iteration result is used. -* `ob_keep_best` (bool, default `True`): If `True`, save α, β parameters at the best iteration. Otherwise the last iteration result is used. +* `use_shared_alpha` (bool, default `False`): If `True`, all intermediate neurons from the same layer share the same set of α variables during bound optimization. For a very large model, enabling this option can save memory, at the cost of slightly looser bounds. + +* `fix_intermediate_layer_bounds` (bool, default `True`): If `True`, only the bounds of the final layer are optimized during α/β-CROWN; intermediate layer bounds are kept fixed. + +* `init_alpha` (bool, default `True`): Initialize α variables by running CROWN once. + +* `early_stop_patience` (int, default 10): Number of iterations with no improvement after which early stopping is considered. + +* `start_save_best` (float, default 0.5): Start saving the best optimized bounds once `current_iteration > int(iteration * start_save_best)`. -* `ob_alpha_share_slopes` (bool, default `False`): If `True`, all intermediate neurons from the same layer share the same set of α variables during bound optimization. For a very large model, enabling this option can save memory, at a cost of slightly looser bound. ## ReLU (`relu`): diff --git a/doc/src/custom_op.md b/doc/src/custom_op.md index 853aa48..cc62038 100644 --- a/doc/src/custom_op.md +++ b/doc/src/custom_op.md @@ -13,7 +13,7 @@ There are four steps to write an operator: 3. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator. -4. [Register the custom operator](api.html#auto_LiRPA.bound_op_map.register_custom_op). +4. Create a mapping from the operator name (defined in step 1) to the bound class (defined in step 3). Define a `dict` in which each item is such a mapping, and pass the `dict` to the `custom_ops` argument when calling `BoundedModule` (see the [documentation](api.html#auto_LiRPA.BoundedModule)); a sketch is shown after this section. For example, if the operator name is `MyRelu`, and the bound class is `BoundMyRelu`, then add `"MyRelu": BoundMyRelu` to the `dict`. ## Example diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..c067428 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1 @@ +auto_LiRPA diff --git a/examples/language/train.py b/examples/language/train.py index 1d6d2a3..8c6f5b2 100644 --- a/examples/language/train.py +++ b/examples/language/train.py @@ -1,13 +1,14 @@ -""" -A simple script to train certified defense LSTM or Transformer using the auto_LiRPA library.
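To illustrate the `custom_ops` registration described in the `doc/src/custom_op.md` change above, here is a sketch using the documentation's hypothetical names (`MyRelu`/`BoundMyRelu`; `model` and `dummy_input` stand for your network and a tracing input, so this is not runnable as-is):

```python
# Sketch only: BoundMyRelu stands for the Bound subclass from step 3 of the
# custom-operator guide; the mapping key must match the name from step 1.
from auto_LiRPA import BoundedModule

custom_ops = {'MyRelu': BoundMyRelu}  # operator name -> bound class
lirpa_model = BoundedModule(model, dummy_input, custom_ops=custom_ops)
```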
-""" - import argparse +import random import pickle import os +import pdb +import time import logging import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as F from torch.nn import CrossEntropyLoss from torch.utils.tensorboard import SummaryWriter from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput @@ -69,18 +70,11 @@ args = parser.parse_args() -# log writer in Tensorboard writer = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10) file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log')) file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s')) logger.addHandler(file_handler) -random.seed(args.seed) -np.random.seed(args.seed) -torch.manual_seed(args.seed) -torch.cuda.manual_seed_all(args.seed) - -## Step 1: Prepare dataset and Initial original model as usual data_train_all_nodes, data_train, data_dev, data_test = load_data(args.data) if args.robust: data_dev, data_test = clean_data(data_dev), clean_data(data_test) @@ -92,6 +86,10 @@ logger.info('Dataset sizes: {}/{}/{}/{}'.format( len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test))) +random.seed(args.seed) +np.random.seed(args.seed) +torch.manual_seed(args.seed) +torch.cuda.manual_seed_all(args.seed) dummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device) dummy_labels = torch.zeros(1, dtype=torch.long, device=args.device) @@ -102,17 +100,14 @@ elif args.model == 'lstm': dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device) model = LSTM(args, data_train) - + dev_batches = get_batches(data_dev, args.batch_size) test_batches = get_batches(data_test, args.batch_size) -## Step 3: Define perturbation range, here we use synonym replacement perturbation constarint by args.budget ptb = PerturbationSynonym(budget=args.budget) dummy_embeddings = BoundedTensor(dummy_embeddings, ptb) - -## Step 4: wrap model with auto_LiRPA model_ori = model.model_from_embeddings -bound_opts = {'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True} +bound_opts = { 'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True } if isinstance(model_ori, BoundedModule): model_bound = model_ori else: @@ -184,11 +179,9 @@ def step(model, ptb, batch, eps=1.0, train=False): acc = (torch.argmax(logits, dim=1) == labels).float().mean() if robust: - # generate specifications num_class = args.num_classes c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) - \ torch.eye(num_class).type_as(embeddings).unsqueeze(0) - # remove specifications to self I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0))) c = (c[I].view(embeddings.size(0), num_class - 1, num_class)) if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']: @@ -198,15 +191,11 @@ def step(model, ptb, batch, eps=1.0, train=False): if 1 - eps > 1e-4: lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False) ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True) - # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020) lb = eps * ilb + (1 - eps) * lb else: lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP') else: raise NotImplementedError - - # Pad zero at the beginning for each example, and use fake label "0" for all examples because the margins - # have already been 
calculated by specifications lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1) fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device) loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels) @@ -237,7 +226,6 @@ def train(epoch, batches, type): assert(optimizer is not None) train = type == 'train' if args.robust: - # epsilon dynamically growth eps_scheduler.set_epoch_length(len(batches)) if train: eps_scheduler.train() diff --git a/examples/requirements.txt b/examples/requirements.txt index 279168b..2aeee45 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,11 +1,12 @@ thop>=0.0.31.post2004101309 tensorboard>=1.14 scikit_learn>=0.21 -torchvision>=0.9.1 -torch>=1.8,<1.9 +torchvision>=0.9.1,<0.13 +torch>=1.8 scipy>=1.3 pytorch_pretrained_bert>=0.6 query>=0.1 tqdm>=4.43 matplotlib>=3.2 sortedcontainers>=2.4 +psutil>=5.8 diff --git a/examples/vision/cifar_training.py b/examples/vision/cifar_training.py index 9feb315..d4019a9 100644 --- a/examples/vision/cifar_training.py +++ b/examples/vision/cifar_training.py @@ -3,11 +3,11 @@ import random import time import logging +import os import torch.optim as optim import torchvision.datasets as datasets import torchvision.transforms as transforms -from thop import profile from torch.nn import CrossEntropyLoss import models @@ -55,7 +55,7 @@ def get_exp_module(bounded_module): os.makedirs('saved_models/', exist_ok=True) log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log' file_handler = logging.FileHandler(log_file) -logger.addHandler(file_handler) +logger.addHandler(file_handler) def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None): num_class = 10 @@ -172,16 +172,6 @@ def get_bound_loss(x=None, c=None): opt.step() meter.update('Loss', loss.item(), data.size(0)) - # check gradient - # for n, p in model.named_parameters(): - # if p.grad is None: - # print('gradient for layer {} is NULL!!!'.format(n)) - # else: - # print('gradient for layer {} is not null'.format(n)) - # print(p.grad.flatten()[:8]) - # - # sys.exit() - if batch_method != 'natural': meter.update('Robust_CE', robust_ce.item(), data.size(0)) if not loss_fusion: @@ -262,9 +252,6 @@ def main(args): final_name2 = None model_loss = BoundDataParallel(model_loss) - macs, params = profile(model_ori, (dummy_input.cuda(),)) - logger.info('macs: {}, params: {}'.format(macs, params)) - ## Step 4 prepare optimizer, epsilon scheduler and learning rate scheduler opt = optim.Adam(model_loss.parameters(), lr=args.lr) norm = float(args.norm) diff --git a/examples/vision/custom_op.py b/examples/vision/custom_op.py index 1d5d96e..5b70fa0 100644 --- a/examples/vision/custom_op.py +++ b/examples/vision/custom_op.py @@ -1,7 +1,7 @@ """ A example for custom operators. -In this example, we create a custom operator called "PlusConstant", which can -be written as "f(x) = x + c" for some constant "c" (an attribute of the operator). +In this example, we create a custom operator called "PlusConstant", which can +be written as "f(x) = x + c" for some constant "c" (an attribute of the operator). 
""" import torch import torch.nn as nn @@ -11,22 +11,22 @@ from auto_LiRPA.perturbations import PerturbationLpNorm from auto_LiRPA.utils import Flatten -""" Step 1: Define a `torch.autograd.Function` class to declare and implement the +""" Step 1: Define a `torch.autograd.Function` class to declare and implement the computation of the operator. """ class PlusConstantOp(torch.autograd.Function): @staticmethod def symbolic(g, x, const): """ In this function, define the arguments and attributes of the operator. "custom::PlusConstant" is the name of the new operator, "x" is an argument - of the operator, "const_i" is an attribute which stands for "c" in the operator. - There can be multiple arguments and attributes. For attribute naming, - use a suffix such as "_i" to specify the data type, where "_i" stands for + of the operator, "const_i" is an attribute which stands for "c" in the operator. + There can be multiple arguments and attributes. For attribute naming, + use a suffix such as "_i" to specify the data type, where "_i" stands for integer, "_t" stands for tensor, "_f" stands for float, etc. """ return g.op('custom::PlusConstant', x, const_i=const) @staticmethod def forward(ctx, x, const): - """ In this function, implement the computation for the operator, i.e., + """ In this function, implement the computation for the operator, i.e., f(x) = x + c in this case. """ return x + const @@ -43,9 +43,9 @@ def forward(self, x): """ Step 3: Implement a Bound class to support bound computation for the new operator. """ class BoundPlusConstant(Bound): - def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): + def __init__(self, attr, inputs, output_index, options): """ `const` is an attribute and can be obtained from the dict `attr` """ - super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + super().__init__(attr, inputs, output_index, options) self.const = attr['const'] def forward(self, x): @@ -67,7 +67,7 @@ def _bound_oneside(last_A): bias = last_A.sum(dim=list(range(2, last_A.ndim))) * self.const return A, bias lA, lbias = _bound_oneside(last_lA) - uA, ubias = _bound_oneside(last_uA) + uA, ubias = _bound_oneside(last_lA) return [(lA, uA)], lbias, ubias def interval_propagate(self, *v): diff --git a/examples/vision/data/.gitignore b/examples/vision/data/.gitignore new file mode 100644 index 0000000..13ece7c --- /dev/null +++ b/examples/vision/data/.gitignore @@ -0,0 +1,2 @@ +MNIST +cifar* \ No newline at end of file diff --git a/examples/vision/data/tinyImageNet/.gitignore b/examples/vision/data/tinyImageNet/.gitignore new file mode 100644 index 0000000..0c46dd7 --- /dev/null +++ b/examples/vision/data/tinyImageNet/.gitignore @@ -0,0 +1 @@ +tiny-imagenet-200* diff --git a/examples/vision/imagenet_training.py b/examples/vision/imagenet_training.py index 6e9791f..22425c9 100644 --- a/examples/vision/imagenet_training.py +++ b/examples/vision/imagenet_training.py @@ -54,7 +54,7 @@ def get_exp_module(bounded_module): args.num_epochs) + '_' + args.scheduler_opts + '_' + str(args.eps)[:6] log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log' file_handler = logging.FileHandler(log_file) -logger.addHandler(file_handler) +logger.addHandler(file_handler) def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None): diff --git a/examples/vision/simple_verification.py b/examples/vision/simple_verification.py index 9b7629e..b20c61b 100644 --- 
a/examples/vision/simple_verification.py +++ b/examples/vision/simple_verification.py @@ -4,17 +4,15 @@ This example serves as a skeleton for robustness verification of neural networks. """ import os +from collections import defaultdict import torch import torch.nn as nn import torchvision from auto_LiRPA import BoundedModule, BoundedTensor from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import Flatten ## Step 1: Define computational graph by implementing forward() -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) - # This simple model comes from https://github.com/locuslab/convex_adversarial def mnist_model(): model = nn.Sequential( @@ -31,11 +29,15 @@ def mnist_model(): model = mnist_model() # Optionally, load the pretrained weights. -checkpoint = torch.load(os.path.join(os.path.dirname(__file__),"pretrain/mnist_a_adv.pth"), map_location=torch.device('cpu')) +checkpoint = torch.load( + os.path.join(os.path.dirname(__file__), 'pretrain/mnist_a_adv.pth'), + map_location=torch.device('cpu')) model.load_state_dict(checkpoint) ## Step 2: Prepare dataset as usual -test_data = torchvision.datasets.MNIST("./data", train=False, download=True, transform=torchvision.transforms.ToTensor()) +test_data = torchvision.datasets.MNIST( + './data', train=False, download=True, + transform=torchvision.transforms.ToTensor()) # For illustration we only use 2 image from dataset N = 2 n_classes = 10 @@ -48,7 +50,8 @@ def mnist_model(): model = model.cuda() ## Step 3: wrap model with auto_LiRPA -# The second parameter is for constructing the trace of the computational graph, and its content is not important. +# The second parameter is for constructing the trace of the computational graph, +# and its content is not important. lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device) print('Running on', image.device) @@ -60,23 +63,51 @@ def mnist_model(): # Get model prediction as usual pred = lirpa_model(image) label = torch.argmax(pred, dim=1).cpu().detach().numpy() +print('Demonstration 1: Bound computation and comparisons of different methods.\n') ## Step 5: Compute bounds for final output -for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN-Optimized (alpha-CROWN)']: -# for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', ]: - print("Bounding method:", method) +for method in [ + 'IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', + 'CROWN-Optimized (alpha-CROWN)']: + print('Bounding method:', method) if 'Optimized' in method: # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values. 
- lirpa_model.set_bound_opts({'optimize_bound_args': {'ob_iteration': 20, 'ob_lr': 0.1, 'ob_verbose': 0}}) + lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0]) for i in range(N): - print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i])) + print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}') for j in range(n_classes): indicator = '(ground-truth)' if j == true_label[i] else '' - print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format( + print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format( j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator)) print() +print('Demonstration 2: Obtaining linear coefficients of the lower and upper bounds.\n') +# There are many bound coefficients during CROWN bound calculation; here we are interested in the linear bounds +# of the output layer, with respect to the input layer (the image). +required_A = defaultdict(set) +required_A[lirpa_model.output_name[0]].add(lirpa_model.input_name[0]) + +for method in [ + 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN', + 'CROWN-Optimized (alpha-CROWN)']: + print("Bounding method:", method) + if 'Optimized' in method: + # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values. + lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) + lb, ub, A_dict = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], return_A=True, needed_A_dict=required_A) + lower_A, lower_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lbias'] + upper_A, upper_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['uA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['ubias'] + print(f'lower bound linear coefficients size (batch, output_dim, *input_dims): {list(lower_A.size())}') + print(f'lower bound linear coefficients norm (smaller is better): {lower_A.norm()}') + print(f'lower bound bias term size (batch, output_dim): {list(lower_bias.size())}') + print(f'lower bound bias term sum (larger is better): {lower_bias.sum()}') + print(f'upper bound linear coefficients size (batch, output_dim, *input_dims): {list(upper_A.size())}') + print(f'upper bound linear coefficients norm (smaller is better): {upper_A.norm()}') + print(f'upper bound bias term size (batch, output_dim): {list(upper_bias.size())}') + print(f'upper bound bias term sum (smaller is better): {upper_bias.sum()}') + print(f'These linear lower and upper bounds are valid everywhere within the perturbation radii.\n') + ## An example for computing margin bounds. # In compute_bounds() function you can pass in a specification matrix C, which is a final linear matrix applied to the last layer NN output. # For example, if you are interested in the margin between the groundtruth class and another class, you can use C to specify the margin. 
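As background for the specification-matrix code that follows, a mini-example (illustrative only, not from this patch) of what one row of `C` encodes: with `C[i] = e_y - e_t`, the bounded quantity `C @ f(x)` is exactly the margin `f_y(x) - f_t(x)`, so a positive lower bound certifies that class `y` always beats class `t` under the perturbation:

```python
# One specification row: +1 at the ground-truth class y, -1 at the target t.
import torch

n_classes, y, t = 4, 2, 3
C = torch.zeros(1, 1, n_classes)          # (batch, num_specs, n_classes)
C[0, 0, y], C[0, 0, t] = 1.0, -1.0
logits = torch.tensor([[0.1, 0.3, 2.0, 0.5]])
margin = torch.einsum('bsc,bc->bs', C, logits)
print(margin)  # tensor([[1.5000]]) == logits[0, y] - logits[0, t]
```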
@@ -89,17 +120,18 @@ def mnist_model(): target_label = (groundtruth + 1) % n_classes C.scatter_(dim=2, index=groundtruth, value=1.0) C.scatter_(dim=2, index=target_label, value=-1.0) -print('Computing bounds with a specification matrix:\n', C) +print('Demonstration 3: Computing bounds with a specification matrix.\n') +print('Specification matrix:\n', C) for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN-Optimized (alpha-CROWN)']: - print("Bounding method:", method) + print('Bounding method:', method) if 'Optimized' in method: # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values. - lirpa_model.set_bound_opts({'optimize_bound_args': {'ob_iteration': 20, 'ob_lr': 0.1, 'ob_verbose': 0}}) + lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1, }}) lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], C=C) for i in range(N): - print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i])) - print("margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}".format( + print('Image {} top-1 prediction {} ground-truth {}'.format(i, label[i], true_label[i])) + print('margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}'.format( + j=true_label[i], target=(true_label[i] + 1) % n_classes, l=lb[i][0].item(), u=ub[i][0].item())) print() diff --git a/examples/vision/tinyimagenet_training.py b/examples/vision/tinyimagenet_training.py index 3f079a5..3264ca2 100644 --- a/examples/vision/tinyimagenet_training.py +++ b/examples/vision/tinyimagenet_training.py @@ -1,3 +1,4 @@ +import os import random import time import argparse @@ -55,7 +56,7 @@ def get_exp_module(bounded_module): os.makedirs('saved_models/', exist_ok=True) log_file = f'saved_models/{exp_name}{"_test" if args.verify else ""}.log' file_handler = logging.FileHandler(log_file) -logger.addHandler(file_handler) +logger.addHandler(file_handler) def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None): diff --git a/examples/vision/verify_two_node.py b/examples/vision/verify_two_node.py index 7bbd697..9895d8a 100644 --- a/examples/vision/verify_two_node.py +++ b/examples/vision/verify_two_node.py @@ -1,10 +1,11 @@ """ -Example for multi-node perturbation. An input image is splited to two parts -where each part is perturbed respectively constained by L-inf norm. It is -expected to output the same results as running `simple_verification.py` where +Example for multi-node perturbation. An input image is split into two parts +where each part is perturbed separately, constrained by the L-inf norm. It is +expected to output the same results as running `simple_verification.py` where the whole image is perturbed constrained by the L-inf norm.
""" +import os import torch.nn as nn import torch.nn.functional as F import torchvision @@ -35,7 +36,8 @@ def forward(self, x, y): model.load_state_dict(checkpoint) ## Step 2: Prepare dataset as usual -test_data = torchvision.datasets.MNIST("./data", train=False, download=True, transform=torchvision.transforms.ToTensor()) +test_data = torchvision.datasets.MNIST( + "./data", train=False, download=True, transform=torchvision.transforms.ToTensor()) # For illustration we only use 2 image from dataset N = 2 n_classes = 10 @@ -44,9 +46,13 @@ def forward(self, x, y): image = image.to(torch.float32) / 255.0 ## Step 3: wrap model with auto_LiRPA -# The second parameter is for constructing the trace of the computational graph, and its content is not important. +# The second parameter is for constructing the trace of the computational graph, +# and its content is not important. image_1, image_2 = torch.split(torch.empty_like(image), [14, 14], dim=2) -model = BoundedModule(model, (image_1, image_2), device="cuda") +model = BoundedModule( + model, (image_1, image_2), device="cuda", + bound_opts={'conv_mode': 'matrix'} # Patches mode is not supported currently +) ## Step 4: Compute bounds using LiRPA given a perturbation eps = 0.3 @@ -68,6 +74,6 @@ def forward(self, x, y): for i in range(N): print("Image {} top-1 prediction {}".format(i, label[i])) for j in range(n_classes): - print("f_{j}(x_0) = {fx0:8.3f}, {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j])) + print("f_{j}(x_0) = {fx0:8.3f}, {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format( + j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j])) print() - diff --git a/examples/vision/weight_perturbation_training.py b/examples/vision/weight_perturbation_training.py index dbc2bef..13bad55 100644 --- a/examples/vision/weight_perturbation_training.py +++ b/examples/vision/weight_perturbation_training.py @@ -11,6 +11,7 @@ """ import random import time +import os import argparse import logging import torch.optim as optim @@ -217,17 +218,22 @@ def main(args): ## Step 3: wrap model with auto_LiRPA # The second parameter dummy_input is for constructing the trace of the computational graph. 
-    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu':args.bound_opts}, device=args.device)
+    model = BoundedModule(model_ori, dummy_input, device=args.device, bound_opts={
+        'relu': args.bound_opts, 'sparse_intermediate_bounds': False,
+        'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})
     final_name1 = model.final_name
     model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
         (dummy_input, torch.zeros(1, dtype=torch.long)),
-        bound_opts= { 'relu': args.bound_opts, 'loss_fusion': True }, device=args.device)
+        device=args.device, bound_opts={'relu': args.bound_opts, 'loss_fusion': True,
+            'sparse_intermediate_bounds': False,
+            'sparse_conv_intermediate_bounds': False,
+            'sparse_intermediate_bounds_with_ibp': False})
 
     # after CrossEntropyWrapper, the final name will change because of one more input node in CrossEntropyWrapper
     final_name2 = model_loss._modules[final_name1].output_name[0]
     assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
     if args.multigpu:
         model_loss = BoundDataParallel(model_loss)
-    model_loss.ptb = model.ptb = model_ori.ptb # Perturbation on the parameters 
+    model_loss.ptb = model.ptb = model_ori.ptb  # Perturbation on the parameters
 
     ## Step 4 prepare optimizer, epsilon scheduler and learning rate scheduler
     if args.opt == 'ADAM':
diff --git a/setup.py b/setup.py
index 0807d51..5428841 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,8 @@
 from setuptools import setup, find_packages
-import sys
 
 """Check PyTorch version"""
 pytorch_version_l = "1.8.0"
-pytorch_version_u = "1.9.0" # excluded
+pytorch_version_u = "1.13.0" # excluded
 msg_install_pytorch = (f"It is recommended to manually install PyTorch "
-    f"(>={pytorch_version_u},<{pytorch_version_u}) suitable "
+    f"(>={pytorch_version_l},<{pytorch_version_u}) suitable "
     "for your system ahead: https://pytorch.org/get-started.\n")
@@ -14,7 +13,7 @@
             + msg_install_pytorch)
     if torch.__version__ >= pytorch_version_u:
         print(f'PyTorch version {torch.__version__} is too high. '
-            + msg_install_pytorch)
+              + msg_install_pytorch)
 except ModuleNotFoundError:
     print(f'PyTorch is not installed. 
{msg_install_pytorch}') @@ -23,13 +22,6 @@ if '__version__' in line: version = eval(line.strip().split()[-1]) -assert sys.version_info.major == 3, 'Python 3 is required' -if sys.version_info.minor < 8: - # numpy 1.22 requires Python 3.8+ - numpy_requirement = 'numpy>=1.16,<=1.21' -else: - numpy_requirement = 'numpy>=1.16' - print(f'Installing auto_LiRPA {version}') setup( name='auto_LiRPA', @@ -41,12 +33,15 @@ packages=find_packages(), install_requires=[ f'torch>={pytorch_version_l},<{pytorch_version_u}', - 'torchvision>=0.9,<0.10', - numpy_requirement, + 'torchvision>=0.9,<0.14', + 'numpy>=1.16', 'packaging>=20.0', 'pytest>=5.0', + 'pylint>=2.15', + 'pytest-order>=1.0.0', 'appdirs>=1.4', 'pyyaml>=5.0', + 'ninja>=1.10', ], platforms=['any'], license='BSD', diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..16d3c4d --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +.cache diff --git a/tests/auto_LiRPA b/tests/auto_LiRPA new file mode 120000 index 0000000..a0acdc3 --- /dev/null +++ b/tests/auto_LiRPA @@ -0,0 +1 @@ +../auto_LiRPA \ No newline at end of file diff --git a/tests/data/.gitignore b/tests/data/.gitignore index 677d9db..c40e49a 100644 --- a/tests/data/.gitignore +++ b/tests/data/.gitignore @@ -2,4 +2,4 @@ ckpt_lstm ckpt_transformer cifar-10-python.tar.gz cifar-10-batches-py - +MNIST \ No newline at end of file diff --git a/tests/data/beta_crown_test_data b/tests/data/beta_crown_test_data new file mode 100644 index 0000000..6090ff8 Binary files /dev/null and b/tests/data/beta_crown_test_data differ diff --git a/tests/data/maxpool_test_data b/tests/data/maxpool_test_data index 4d230bc..b538cc9 100755 Binary files a/tests/data/maxpool_test_data and b/tests/data/maxpool_test_data differ diff --git a/tests/data/maxpool_test_data_3-0-3-0 b/tests/data/maxpool_test_data_3-0-3-0 new file mode 100644 index 0000000..41191fa Binary files /dev/null and b/tests/data/maxpool_test_data_3-0-3-0 differ diff --git a/tests/data/maxpool_test_data_3-0-3-1 b/tests/data/maxpool_test_data_3-0-3-1 new file mode 100644 index 0000000..351e86d Binary files /dev/null and b/tests/data/maxpool_test_data_3-0-3-1 differ diff --git a/tests/data/maxpool_test_data_4-0-4-0 b/tests/data/maxpool_test_data_4-0-4-0 new file mode 100644 index 0000000..e054b4a Binary files /dev/null and b/tests/data/maxpool_test_data_4-0-4-0 differ diff --git a/tests/data/maxpool_test_data_4-0-4-1 b/tests/data/maxpool_test_data_4-0-4-1 new file mode 100644 index 0000000..83f17ef Binary files /dev/null and b/tests/data/maxpool_test_data_4-0-4-1 differ diff --git a/tests/data/vision_test_data b/tests/data/vision_test_data index 31715a5..9f2b7fd 100644 Binary files a/tests/data/vision_test_data and b/tests/data/vision_test_data differ diff --git a/tests/test_1d_activation.py b/tests/test_1d_activation.py index 22dfa2a..232b47b 100644 --- a/tests/test_1d_activation.py +++ b/tests/test_1d_activation.py @@ -1,5 +1,6 @@ """Test one dimensional activation functions (e.g., ReLU, tanh, exp, sin, etc)""" import torch +import torch.nn as nn import os from testcase import TestCase from auto_LiRPA import BoundedModule, BoundedTensor @@ -21,6 +22,8 @@ def __init__(self, methodName='runTest'): super().__init__(methodName) def create_test(self, act_func, low, high, ntests=10000, nsamples=1000, method='IBP'): + print(f'Testing activation {act_func}') + model = test_model(act_func) image = torch.zeros(1, ntests) bounded_model = BoundedModule(model, image) @@ -83,6 +86,22 @@ def lookup(l, u): assert torch.all(output_lb - 1e-5 
<= ref_output_lb)
 
+    def _single(self):
+        model = test_model(torch.sin)
+        image = torch.zeros(1, 1)
+        bounded_model = BoundedModule(model, image)
+
+        input_lb = torch.tensor([2.817])
+        input_ub = torch.tensor([5.196])
+        input_center = (input_lb + input_ub) / 2.0
+        ptb = PerturbationLpNorm(norm=float("inf"), eps=None, x_L=input_lb, x_U=input_ub)
+        ptb_data = BoundedTensor(input_center, ptb)
+
+        # Get bounding results.
+        forward = bounded_model(ptb_data)
+        output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data,), method='CROWN')
+        print(output_lb, output_ub)
+
     def test_relu(self):
         self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='IBP')
         self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='CROWN')
@@ -106,13 +125,22 @@ def test_tanh(self):
 
     def test_sin(self):
         self.create_test(act_func=torch.sin, low=-10, high=10, method='IBP')
-        # self.create_test(act_func=torch.sin, low=-10, high=10, method='CROWN')
+        self.create_test(act_func=torch.sin, low=-10, high=10, method='CROWN')
 
     def test_cos(self):
         self.create_test(act_func=torch.cos, low=-10, high=10, method='IBP')
-        # self.create_test(act_func=torch.cos, low=-10, high=10, method='CROWN')
+        self.create_test(act_func=torch.cos, low=-10, high=10, method='CROWN')
 
+    def test_arctan(self):
+        self.create_test(act_func=torch.arctan, low=-10, high=10, method='IBP')
+        self.create_test(act_func=torch.arctan, low=-10, high=10, method='CROWN')
+
+    def test_tan(self):
+        # Test tan(x) in different periods.
+        for i in range(-5, 5):
+            self.create_test(act_func=torch.tan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='IBP')
+            self.create_test(act_func=torch.tan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='CROWN')
 
 if __name__ == '__main__':
     testcase = Test1DActivation()
@@ -122,5 +150,5 @@ def test_cos(self):
     testcase.test_tanh()
     testcase.test_sin()
     testcase.test_cos()
-
-
+    testcase.test_arctan()
+    testcase.test_tan()
diff --git a/tests/test_pooling.py b/tests/test_avgpool.py
similarity index 96%
rename from tests/test_pooling.py
rename to tests/test_avgpool.py
index bc5ec23..cf4b45b 100644
--- a/tests/test_pooling.py
+++ b/tests/test_avgpool.py
@@ -1,4 +1,4 @@
-# Test bounds on a 1 layer linear network.
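The `create_test` helper in `test_1d_activation.py` above validates soundness by sampling: it draws points inside each input interval and checks that the true activation values stay within the computed output bounds. `test_tan` keeps each interval inside a single period because tan diverges at its asymptotes, so an interval crossing a pole has no finite bounds. A condensed sketch of the sampling check (an illustration, not the test itself; names and tolerances follow the test):

```python
import torch

def check_soundness(act_func, input_lb, input_ub, output_lb, output_ub, nsamples=1000):
    # Draw uniform samples inside each [input_lb, input_ub] interval.
    x = input_lb + (input_ub - input_lb) * torch.rand(nsamples, *input_lb.shape)
    y = act_func(x)
    # Every sampled activation value must fall within the computed bounds
    # (with a small slack for floating-point error).
    assert torch.all(y >= output_lb - 1e-5)
    assert torch.all(y <= output_ub + 1e-5)
```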
+"""Test average pooling.""" import torch.nn as nn from auto_LiRPA import BoundedModule, BoundedTensor diff --git a/tests/test_bound_ops.py b/tests/test_bound_ops.py index 5b3b157..a49adc4 100644 --- a/tests/test_bound_ops.py +++ b/tests/test_bound_ops.py @@ -1,10 +1,10 @@ """Test classes for bound operators""" import torch -import os from auto_LiRPA.bound_ops import * -from auto_LiRPA.utils import LinearBound +from auto_LiRPA.linear_bound import LinearBound from testcase import TestCase + """Dummy node for testing""" class Dummy: def __init__(self, lower, upper=None, perturbed=False): @@ -36,10 +36,9 @@ def test(self): dummy_bias = Dummy(bias) op = BoundLinear( - input_name=[None, None, None], - name=None, ori_name=None, attr=None, + attr={}, inputs=[dummy_in, dummy_weight, dummy_bias], - output_index=0, options={}, device=device) + output_index=0, options={}) # test `forward` data_out = op(data_in, weight, bias) diff --git a/tests/test_constant.py b/tests/test_constant.py index cc4a42e..756a125 100644 --- a/tests/test_constant.py +++ b/tests/test_constant.py @@ -57,4 +57,5 @@ def test(self): if __name__ == '__main__': # Change to generate=True when genearting reference results testcase = TestConstant(generate=False) + testcase.setUp() testcase.test() diff --git a/tests/test_conv.py b/tests/test_conv.py index 9390384..b9e571f 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -15,7 +15,7 @@ def forward(self, x): return x.view((x.shape[0], -1)) class cnn_model(nn.Module): - def __init__(self, layers, padding, stride): + def __init__(self, layers, padding, stride, linear=True): super(cnn_model, self).__init__() self.module_list = [] channel = 1 @@ -27,8 +27,9 @@ def __init__(self, layers, padding, stride): assert length > 0 self.module_list.append(nn.ReLU()) self.module_list.append(Flatten()) - self.module_list.append(nn.Linear(3 * length * length, 256)) - self.module_list.append(nn.Linear(256, 10)) + if linear: + self.module_list.append(nn.Linear(3 * length * length, 256)) + self.module_list.append(nn.Linear(256, 10)) self.model = nn.Sequential(*self.module_list) def forward(self, x): @@ -42,7 +43,7 @@ def __init__(self, methodName='runTest', generate=False): generate=generate) def test(self): - models = [2, 3] + models = [1, 2, 3] paddings = [1, 2] strides = [1, 3] @@ -54,26 +55,32 @@ def test(self): for layer_num in models: for padding in paddings: for stride in strides: - try: + for linear in [True, False]: model_ori = cnn_model(layer_num, padding, stride) - except: - continue + print('Model:', model_ori) - model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "patches"}) - eps = 0.3 - norm = np.inf - ptb = PerturbationLpNorm(norm=norm, eps=eps) - image = BoundedTensor(image, ptb) - pred = model(image) - lb, ub = model.compute_bounds() + model = BoundedModule(model_ori, image, bound_opts={"conv_mode": "patches"}) + eps = 0.3 + norm = np.inf + ptb = PerturbationLpNorm(norm=norm, eps=eps) + image = BoundedTensor(image, ptb) + pred = model(image) + lb, ub = model.compute_bounds() - model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "matrix"}) - pred = model(image) - lb_ref, ub_ref = model.compute_bounds() + model = BoundedModule(model_ori, image, bound_opts={"conv_mode": "matrix"}) + pred = model(image) + lb_ref, ub_ref = model.compute_bounds() - assert lb.shape == ub.shape == torch.Size((N, n_classes)) - self.assertEqual(lb, lb_ref) - self.assertEqual(ub, ub_ref) + if linear: + assert lb.shape == ub.shape == torch.Size((N, 
n_classes))
+                            self.assertEqual(lb, lb_ref)
+                            self.assertEqual(ub, ub_ref)
+
+                        if not linear and layer_num == 1:
+                            pred = model(image)
+                            lb_forward, ub_forward = model.compute_bounds(method='forward')
+                            self.assertEqual(lb, lb_forward)
+                            self.assertEqual(ub, ub_forward)
 
 if __name__ == '__main__':
     testcase = TestConv()
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 0fae3d7..f2de101 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,4 +1,7 @@
-"""Test all the examples before release"""
+"""Test all the examples before release.
+
+This script is expected to be manually run and is not used in automatic tests."""
+
 import pytest
 import subprocess
 import os
@@ -8,7 +11,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--test', type=str, default=None)
-args = parser.parse_args() 
+args = parser.parse_args()
 
 pytest_skip = pytest.mark.skip(
     reason="It should be tested on a GPU server and excluded from CI")
@@ -24,14 +27,14 @@ def download_data_language():
     url = "http://download.huan-zhang.com/datasets/language/data_language.tar.gz"
     if not os.path.exists('../examples/language/data/sst'):
         subprocess.run(shlex.split(f"wget {url}"), cwd="../examples/language")
-        subprocess.run(shlex.split(f"tar xvf data_language.tar.gz"), 
+        subprocess.run(shlex.split("tar xvf data_language.tar.gz"),
             cwd="../examples/language")
 
 @pytest_skip
 def test_transformer():
     cmd = f"""python train.py --dir {cache_dir} --robust --method IBP+backward_train --train
         --num_epochs 2 --num_epochs_all_nodes 2
-        --eps_start 2 --eps_length 1 --eps 0.1""" 
+        --eps_start 2 --eps_length 1 --eps 0.1"""
     print(cmd, file=sys.stderr)
     download_data_language()
     subprocess.run(shlex.split(cmd), cwd='../examples/language')
@@ -47,19 +50,36 @@ def test_lstm():
     download_data_language()
     subprocess.run(shlex.split(cmd), cwd='../examples/language')
 
-#FIXME this is broken
 @pytest_skip
 def test_lstm_seq():
     cmd = f"""python train.py --dir {cache_dir} --hidden_size 2 --num_epochs 2 --num_slices 4"""
     print(cmd, file=sys.stderr)
-    subprocess.run(shlex.split(cmd), cwd='../examples/sequence') 
+    subprocess.run(shlex.split(cmd), cwd='../examples/sequence')
 
 @pytest_skip
 def test_simple_verification():
     cmd = "python simple_verification.py"
     print(cmd, file=sys.stderr)
-    subprocess.run(shlex.split(cmd), cwd='../examples/vision') 
+    subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_custom_op():
+    cmd = "python custom_op.py"
+    print(cmd, file=sys.stderr)
+    subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_efficient_convolution():
+    cmd = "python efficient_convolution.py"
+    print(cmd, file=sys.stderr)
+    subprocess.run(shlex.split(cmd), cwd='../examples/vision')
+
+@pytest_skip
+def test_two_node():
+    cmd = "python verify_two_node.py"
+    print(cmd, file=sys.stderr)
+    subprocess.run(shlex.split(cmd), cwd='../examples/vision')
 
 @pytest_skip
 def test_simple_training():
@@ -92,8 +112,8 @@ def test_tinyimagenet():
         --in_planes 2 --widen_factor 2"""
     print(cmd, file=sys.stderr)
     if not os.path.exists('../examples/vision/data/tinyImageNet/tiny-imagenet-200'):
-        subprocess.run(shlex.split("bash tinyimagenet_download.sh"), 
-            cwd="../examples/vision/data/tinyImageNet") 
+        subprocess.run(shlex.split("bash tinyimagenet_download.sh"),
+            cwd="../examples/vision/data/tinyImageNet")
     subprocess.run(shlex.split(cmd), cwd='../examples/vision')
 
 @pytest_skip
@@ -103,7 +123,7 @@ def test_imagenet():
         --num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255} --in_planes 2 
--widen_factor 2""" print(cmd) - if (not os.path.exists('../examples/vision/data/ImageNet64/train') or + if (not os.path.exists('../examples/vision/data/ImageNet64/train') or not os.path.exists('../examples/vision/data/ImageNet64/test')): print('Error: ImageNet64 dataset is not ready.') return -1 @@ -111,19 +131,18 @@ def test_imagenet(): @pytest_skip def test_release(): - if args.test: - # Only run a specified test - eval(f'test_{args.test}')() - else: - # Run all tests - test_simple_training() - test_transformer() - test_lstm() - test_lstm_seq() - test_simple_verification() - test_cifar_training() - test_weight_perturbation() - test_tinyimagenet() + """Run all tests.""" + test_simple_training() + test_transformer() + test_lstm() + test_lstm_seq() + test_simple_verification() + test_cifar_training() + test_weight_perturbation() + test_tinyimagenet() + test_custom_op() + test_efficient_convolution() + test_two_node() if __name__ == '__main__': test_release() diff --git a/tests/test_language_models.py b/tests/test_language_models.py index 72cf035..aa4f6aa 100644 --- a/tests/test_language_models.py +++ b/tests/test_language_models.py @@ -43,10 +43,14 @@ def train(): os.system("rm -rf ../examples/language/model_transformer_test") if os.path.exists("../examples/language/model_lstm_test"): os.system("rm -rf ../examples/language/model_lstm_test") - logger.info("Training a Transformer") + logger.info("\nTraining a Transformer") + print(cmd_transformer_train) + print() os.system(cmd_transformer_train) os.system("cp ../examples/language/model_transformer_test/ckpt_2 data/ckpt_transformer") - logger.info("Training an LSTM") + logger.info("\nTraining an LSTM") + print(cmd_lstm_train) + print() os.system(cmd_lstm_train) os.system("cp ../examples/language/model_lstm_test/ckpt_2 data/ckpt_lstm") @@ -55,10 +59,14 @@ def read_res(): return pickle.load(file) def evaluate(): - logger.info('Evaluating the trained LSTM') + logger.info('\nEvaluating the trained LSTM') + print(cmd_lstm_test) + print() os.system(cmd_lstm_test) res_lstm = read_res() - logger.info('Evaluating the trained Transformer') + logger.info('\nEvaluating the trained Transformer') + print(cmd_transformer_test) + print() os.system(cmd_transformer_test) res_transformer = read_res() os.system("rm {}".format(res_path)) diff --git a/tests/test_linear_cnn_model.py b/tests/test_linear_cnn_model.py index 975336b..eb0cb10 100644 --- a/tests/test_linear_cnn_model.py +++ b/tests/test_linear_cnn_model.py @@ -1,4 +1,4 @@ -# Test bounds on a 1 layer CNN network. +"""Test bounds on a 1 layer CNN network.""" import torch.nn as nn from auto_LiRPA import BoundedModule, BoundedTensor @@ -19,7 +19,7 @@ def forward(self, x): x = x.view(-1, input_dim //2 * input_dim // 2 * out_channel) return x -class TestLinearCNNModel(TestLinearModel): +class TestLinearCNNModel(TestLinearModel): def __init__(self, methodName='runTest', generate=False): super().__init__(methodName) self.original_model = LinearCNNModel() diff --git a/tests/test_linear_model.py b/tests/test_linear_model.py index 305d661..b762195 100644 --- a/tests/test_linear_model.py +++ b/tests/test_linear_model.py @@ -1,4 +1,4 @@ -# Test bounds on a 1 layer linear network. 
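The helpers in `test_examples.py` above all follow the same pattern: echo a command, then shell it out with `subprocess`. One thing they do not do is assert on the exit status; a hedged sketch of a stricter runner (the `run_example` name is hypothetical, not part of the patch):

```python
import shlex
import subprocess
import sys

def run_example(cmd, cwd):
    # Echo the command like the helpers above, then fail loudly on a
    # non-zero exit, which the current helpers do not check.
    print(cmd, file=sys.stderr)
    result = subprocess.run(shlex.split(cmd), cwd=cwd)
    assert result.returncode == 0, f'{cmd} exited with code {result.returncode}'

# Example usage:
# run_example('python simple_verification.py', cwd='../examples/vision')
```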
+"""Test bounds on a 1 layer linear network.""" import torch.nn as nn from auto_LiRPA import BoundedModule, BoundedTensor @@ -17,7 +17,7 @@ def forward(self, x): x = self.fc(x) return x -class TestLinearModel(TestCase): +class TestLinearModel(TestCase): def __init__(self, methodName='runTest', generate=False): super().__init__(methodName, seed=0) self.original_model = LinearModel() diff --git a/tests/test_maxpool.py b/tests/test_maxpool.py index 3b731d1..e77e15c 100644 --- a/tests/test_maxpool.py +++ b/tests/test_maxpool.py @@ -1,82 +1,120 @@ +"""Test max pooling.""" + import torch import os import torch.nn as nn import torch.nn.functional as F import torchvision from auto_LiRPA import BoundedModule, BoundedTensor -from auto_LiRPA.perturbations import * +from auto_LiRPA.perturbations import * +from auto_LiRPA.utils import Flatten from testcase import TestCase -class Flatten(nn.Module): - def __init__(self): - super(Flatten, self).__init__() - - def forward(self, x): - return x.view((x.shape[0], -1)) +def MadryCNN(): + return nn.Sequential( + nn.Conv2d(1, 32, 5, stride=1, padding=2), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + nn.Conv2d(32, 64, 5, stride=1, padding=2), + nn.ReLU(), + nn.MaxPool2d(2, stride=2), + Flatten(), + nn.Linear(64*7*7,1024), + nn.ReLU(), + nn.Linear(1024, 10) + ) class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.n_n_conv2d = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 32, 'padding': [0, 0], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True}) - self.n_n_average_pooling2d = nn.MaxPool2d(**{'kernel_size': [4, 4], 'ceil_mode': False, 'stride': [4, 4], 'padding': [0, 0]}) - self.n_n_flatten_Flatten = nn.Flatten(**{'start_dim': 1}) - self.n_n_dense = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 10, 'padding': [0, 0], 'kernel_size': (1, 1), 'stride': [1, 1], 'in_channels': 1152, 'bias': True}) - self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1}) - - def forward(self, *inputs): - t_ImageInputLayer, = inputs - t_conv2d = self.n_n_conv2d(t_ImageInputLayer) - t_conv2d_relu = F.relu(t_conv2d) - t_average_pooling2d = self.n_n_average_pooling2d(t_conv2d_relu)[:, :, :, :] - t_flatten_Transpose = t_average_pooling2d.permute(*[0, 2, 3, 1]) - t_flatten_Flatten = self.n_n_flatten_Flatten(t_flatten_Transpose) - t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Flatten, 2) - t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Unsqueeze, 3) - t_dense = self.n_n_dense(t_flatten_Unsqueeze) - t_activation_Flatten = self.n_n_activation_Flatten(t_dense) - return t_activation_Flatten - -class TestConv(TestCase): + def __init__(self, kernel_size=4, stride=4, padding=0, conv_padding=0): + super(Model, self).__init__() + self.n_n_conv2d = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 1, 'padding': [0, 0], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True}) + self.n_n_maxpool = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]}) + self.n_n_conv2d_2 = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 8, 'padding': [conv_padding, conv_padding], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True}) + self.n_n_maxpool_2 = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]}) + self.n_n_flatten_Flatten = nn.Flatten(**{'start_dim': 1}) + + self.n_n_dense = None + + 
self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1}) + + def forward(self, *inputs): + t_ImageInputLayer, = inputs + t_conv2d = self.n_n_conv2d(t_ImageInputLayer) + t_conv2d_relu = F.relu(t_conv2d) + t_maxpool = self.n_n_maxpool(t_conv2d_relu)[:, :, :, :] + t_conv2d_max = self.n_n_conv2d_2(t_maxpool) + t_conv2d_max = F.relu(t_conv2d_max) + t_maxpool_2 = self.n_n_maxpool_2(t_conv2d_max) + t_flatten_Transpose = t_maxpool_2.permute(*[0, 2, 3, 1]) + t_flatten_Flatten = self.n_n_flatten_Flatten(t_flatten_Transpose) + t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Flatten, 2) + t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Unsqueeze, 3) + + if self.n_n_dense is None: + self.n_n_dense = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 10, 'padding': [0, 0], 'kernel_size': (1, 1), 'stride': [1, 1], 'in_channels': t_flatten_Unsqueeze.shape[1], 'bias': True}) + t_dense = self.n_n_dense(t_flatten_Unsqueeze) + t_activation_Flatten = self.n_n_activation_Flatten(t_dense) + + return t_activation_Flatten + +class TestMaxPool(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, + super().__init__(methodName, seed=1, ref_path=None, generate=generate) def test(self): np.random.seed(123) - models = [2, 3] - paddings = [1, 2] - strides = [1, 3] - - model_ori = Model() - data = torch.load('data/maxpool_test_data') - model_ori.load_state_dict(data['model']) N = 2 - n_classes = 10 - image = data['input'] - # image = torch.rand([N,1,28,28]) - # image = image.to(torch.float32) / 255.0 - model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": "matrix"}) - eps = 0.3 - norm = np.inf - ptb = PerturbationLpNorm(norm=norm, eps=eps) - image = BoundedTensor(image, ptb) - pred = model(image) - lb, ub = model.compute_bounds() + for kernel_size in [3,4]: + for padding in [0]: + for conv_padding in [0,1]: + print(kernel_size, padding, kernel_size, conv_padding) + + model_ori = Model(kernel_size=kernel_size, padding=padding, stride=kernel_size, conv_padding=conv_padding) + if not self.generate: + data = torch.load('data/maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding)) + image = data['input'] + model_ori(image) + model_ori.load_state_dict(data['model']) + else: + image = torch.rand([N, 1, 28, 28]) + model_ori(image) + + + if self.generate: + conv_mode = "matrix" + else: + conv_mode = "patches" + + model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": conv_mode}) + eps = 0.3 + norm = np.inf + ptb = PerturbationLpNorm(norm=norm, eps=eps) + image = BoundedTensor(image, ptb) + pred = model(image) + + lb, ub = model.compute_bounds() - lb_ref = data['lb'] - ub_ref = data['ub'] - assert torch.allclose(lb, lb_ref, 1e-4) - assert torch.allclose(ub, ub_ref, 1e-4) + if self.generate: + torch.save( + {'model': model_ori.state_dict(), + 'input': image, + 'lb': lb, + 'ub': ub}, 'data/maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding) + ) - # lb, ub = model.compute_bounds(x=(image,), method="CROWN-Optimized") + if not self.generate: + lb_ref = data['lb'] + ub_ref = data['ub'] - # torch.save({'input': image, 'model': model_ori.state_dict(), 'lb': lb, 'ub': ub}, 'data/maxpool_test_data') + assert torch.allclose(lb, lb_ref, 1e-4) + assert torch.allclose(ub, ub_ref, 1e-4) if __name__ == '__main__': - testcase = TestConv() + testcase = TestMaxPool(generate=True) testcase.test() diff --git a/tests/test_relative_ibp.py b/tests/test_relative_ibp.py deleted file 
mode 100644 index 9eb99ff..0000000 --- a/tests/test_relative_ibp.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Test IBP with relative bounds""" -import torch -import os -import torch.nn as nn -import torch.nn.functional as F -from auto_LiRPA import BoundedModule, BoundedTensor -from auto_LiRPA.perturbations import * -from auto_LiRPA.bound_ops import * -from testcase import TestCase -import sys -sys.path.append('../examples/vision') -from models import * - -class TestRelativeIBP(TestCase): - def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName) - - def test(self): - dummy_input = torch.randn(1, 3, 32, 32) - - model_ori = cnn_6layer(in_ch=3, in_dim=32) - model = BoundedModule(model_ori, dummy_input, bound_opts={ 'ibp_relative': True }) - - model_ori_ref = cnn_6layer(in_ch=3, in_dim=32) - model_ori_ref.load_state_dict(model_ori.state_dict()) - model_ref = BoundedModule(model_ori_ref, dummy_input, bound_opts={ 'ibp_relative': False }) - - eps = 1e-1 - data = torch.randn(8, 3, 32, 32) - data_lb, data_ub = data - eps, data + eps - ptb = PerturbationLpNorm(norm=np.inf, eps=eps, x_L=data_lb, x_U=data_ub, relative=True) - x = (BoundedTensor(data, ptb),) - - fv = model(x) - fv_ref = model_ref(x) - lb, ub = model.compute_bounds(method='IBP') - lb_ref, ub_ref = model_ref.compute_bounds(method='IBP') - - self.assertEqual(lb, lb_ref) - self.assertEqual(ub, ub_ref) diff --git a/tests/test_simple_verification.py b/tests/test_simple_verification.py new file mode 100644 index 0000000..f648611 --- /dev/null +++ b/tests/test_simple_verification.py @@ -0,0 +1,55 @@ +"""Test optimized bounds in simple_verification.""" +import torch +import torch.nn as nn +import torchvision +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import Flatten +from testcase import TestCase + +# This simple model comes from https://github.com/locuslab/convex_adversarial +def mnist_model(): + model = nn.Sequential( + nn.Conv2d(1, 16, 4, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, 4, stride=2, padding=1), + nn.ReLU(), + Flatten(), + nn.Linear(32*7*7,100), + nn.ReLU(), + nn.Linear(100, 10) + ) + return model + +class TestSimpleVerification(TestCase): + def __init__(self, methodName='runTest'): + super().__init__(methodName) + + def test(self): + model = mnist_model() + checkpoint = torch.load( + '../examples/vision/pretrain/mnist_a_adv.pth', + map_location=torch.device('cpu')) + model.load_state_dict(checkpoint) + + test_data = torchvision.datasets.MNIST( + './data', train=False, download=True, transform=torchvision.transforms.ToTensor()) + N = 2 + image = test_data.data[:N].view(N,1,28,28) + image = image.to(torch.float32) / 255.0 + if torch.cuda.is_available(): + image = image.cuda() + model = model.cuda() + + lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device) + ptb = PerturbationLpNorm(0.3) + image = BoundedTensor(image, ptb) + + method = 'CROWN-Optimized (alpha-CROWN)' + lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) + _, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0]) + self.assertTensorEqual(ub[0][7], torch.tensor(12.5080)) + +if __name__ == '__main__': + testcase = TestSimpleVerification() + testcase.test() diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py index 4ac7b25..7a5019a 100644 --- a/tests/test_vision_models.py +++ b/tests/test_vision_models.py @@ -1,4 +1,5 @@ -import random +import torch 
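`test_maxpool.py` above and `test_vision_models.py` below share the same generate-then-compare idiom: run once manually with `generate=True` to serialize inputs, weights, and bounds into `tests/data`, then let CI compare against that file. A minimal sketch of the pattern (function and parameter names are illustrative):

```python
import torch

def run_reference_test(compute_bounds_fn, model, x, path, generate=False):
    # compute_bounds_fn wraps `model` and returns (lb, ub) for input x.
    lb, ub = compute_bounds_fn(model, x)
    if generate:
        # Run once manually to create the reference file under tests/data.
        torch.save({'model': model.state_dict(), 'input': x, 'lb': lb, 'ub': ub}, path)
    else:
        data = torch.load(path)
        assert torch.allclose(lb, data['lb'], 1e-4)
        assert torch.allclose(ub, data['ub'], 1e-4)
```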
+import torch.nn as nn
 import torch.nn.functional as F
 from auto_LiRPA import BoundedModule, BoundedTensor
 from auto_LiRPA.perturbations import *
@@ -23,26 +24,28 @@ def forward(self, x):
         return x
 
-class TestVisionModels(TestCase): 
+class TestVisionModels(TestCase):
     def __init__(self, methodName='runTest', generate=False):
-        super().__init__(methodName, seed=1234, ref_path='data/vision_test_data')
+        super().__init__(methodName, seed=1234,
+                         ref_path='data/vision_test_data', generate=generate)
         self.result = {}
 
     def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name):
         lb, ub = model(method_opt="compute_bounds", x=(x,), IBP=IBP, method=method)
         self.result[lb_name] = lb
         self.result[ub_name] = ub
-        assert torch.allclose(lb, self.reference[lb_name], 1e-4), (lb - self.reference[lb_name]).abs().sum()
-        assert torch.allclose(ub, self.reference[ub_name], 1e-4), (ub - self.reference[ub_name]).abs().sum()
-        assert ((lb - self.reference[lb_name]).pow(2).sum() < 1e-9), (lb - self.reference[lb_name]).pow(2).sum()
-        assert ((ub - self.reference[ub_name]).pow(2).sum() < 1e-9), (ub - self.reference[ub_name]).pow(2).sum()
 
         # test gradient backward propagation
         loss = (ub - lb).abs().sum()
         loss.backward()
         grad = x.grad
-        self.result[lb_name[:-2] + 'grad'] = grad
-        assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6)
-        assert (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum() < 1e-9
+        self.result[lb_name[:-2] + 'grad'] = grad.clone()
+        if not self.generate:
+            assert torch.allclose(lb, self.reference[lb_name], 1e-4), (lb - self.reference[lb_name]).abs().sum()
+            assert torch.allclose(ub, self.reference[ub_name], 1e-4), (ub - self.reference[ub_name]).abs().sum()
+            assert ((lb - self.reference[lb_name]).pow(2).sum() < 1e-9), (lb - self.reference[lb_name]).pow(2).sum()
+            assert ((ub - self.reference[ub_name]).pow(2).sum() < 1e-9), (ub - self.reference[ub_name]).pow(2).sum()
+            assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6)
+            assert (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum() < 1e-9
 
     def test_bounds(self):
         np.random.seed(123) # FIXME inconsistent seeds
@@ -83,6 +86,13 @@ def test_bounds(self):
             ub_name='l_2_CROWN_ub') # CROWN
 
         if self.generate:
-            for item in self.result:
-                self.reference = self.result[item]
+            self.result['data'] = self.reference['data']
+            self.result['model'] = self.reference['model']
             self.save()
+
+
+if __name__ == "__main__":
+    # t = TestVisionModels(generate=True)
+    t = TestVisionModels()
+    t.setUp()
+    t.test_bounds()
\ No newline at end of file
diff --git a/tests/test_weight_perturbation.py b/tests/test_weight_perturbation.py
index 073f7b4..3cc7ba0 100644
--- a/tests/test_weight_perturbation.py
+++ b/tests/test_weight_perturbation.py
@@ -1,28 +1,26 @@
 import copy
-import random
-import argparse
-import torch.nn.functional as F
 import subprocess
 import numpy as np
 from testcase import TestCase
 import sys
 sys.path.append('../examples/vision')
 import models
-from auto_LiRPA import BoundedModule, BoundedParameter
+from auto_LiRPA import BoundedModule
 from auto_LiRPA.perturbations import *
 
-class TestWeightPerturbation(TestCase): 
+class TestWeightPerturbation(TestCase):
     def __init__(self, methodName='runTest', generate=False):
         super().__init__(methodName, seed=1234,
             ref_path='data/weight_perturbation_test_data')
         self.result = {}
 
     def test_training(self):
+        # python weight_perturbation_training.py --device cpu --scheduler_opts start=1,length=100 --num_epochs 1 --truncate_data 5
         ret = subprocess.run(
-
['python', 'weight_perturbation_training.py', + ['python', 'weight_perturbation_training.py', '--device', 'cpu', '--scheduler_opts', 'start=1,length=100', - '--num_epochs', '1', + '--num_epochs', '1', '--truncate_data', '5'], cwd='../examples/vision', capture_output=True) self.assertEqual(ret.returncode, 0, ret.stderr) @@ -56,11 +54,12 @@ def test_perturbation(self): self.result['model'] = model_ori.state_dict() self.result['data'] = torch.randn(8, 1, 28, 28) model_ori.load_state_dict(self.result['model']) - state_dict = copy.deepcopy(model_ori.state_dict()) + state_dict = copy.deepcopy(model_ori.state_dict()) dummy_input = self.result['data'].requires_grad_() inputs = (dummy_input,) - model = BoundedModule(model_ori, inputs) + model = BoundedModule(model_ori, inputs, bound_opts={ + 'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False}) forward_ret = model(dummy_input) model_ori.eval() @@ -69,7 +68,8 @@ def test_perturbation(self): def verify_model(pert_weight=True, pert_bias=True, norm=np.inf, lb_name='', ub_name=''): model_ori_ = models.Models['mlp_3layer_weight_perturb'](pert_weight=pert_weight, pert_bias=pert_bias, norm=norm).eval() model_ori_.load_state_dict(state_dict) - model_ = BoundedModule(model_ori_, inputs) + model_ = BoundedModule(model_ori_, inputs, bound_opts={ + 'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False}) model_.ptb = model_ori.ptb self.verify_bounds(model_, dummy_input, IBP=True, method='backward', forward_ret=forward_ret,
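Both `BoundedModule` calls in `test_weight_perturbation.py` pass the same three flags. A short sketch that factors them out (the constant name is illustrative; the reason for disabling them is an inference from the patch, not a documented rule):

```python
from auto_LiRPA import BoundedModule

# Options shared by both BoundedModule calls in the test above. The sparse
# intermediate-bound shortcuts are disabled, presumably because those
# optimizations assume only the inputs, not the parameters, are perturbed.
WEIGHT_PTB_OPTS = {
    'sparse_intermediate_bounds': False,
    'sparse_conv_intermediate_bounds': False,
    'sparse_intermediate_bounds_with_ibp': False,
}

# Usage, mirroring the test:
# model = BoundedModule(model_ori, inputs, bound_opts=WEIGHT_PTB_OPTS)
```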