diff --git a/README.md b/README.md index 0d8bd96..fcae89e 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,12 @@ ## What's New? - Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. +- Support for [custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html). (01/02/2022) - [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021) - Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021) - A memory efficient GPU implementation of backward (CROWN) bounds for convolutional layers. (10/31/2020) -- Certified defense models for downscaled -[ImageNet](#imagenet-pretrained), [TinyImageNet](#imagenet-pretrained), [CIFAR-10](#cifar10-pretrained), -and [LSTM/Transformers](#language-pretrained). (08/20/2020) +- Certified defense models for downscaled ImageNet, TinyImageNet, CIFAR-10, LSTM/Transformer. (08/20/2020) - Adding support to **complex vision models** including DenseNet, ResNeXt and WideResNet. (06/30/2020) - **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds (e.g. CROWN-IBP) to the same asympototic complexity of IBP, making LiRPA based certified @@ -143,26 +142,23 @@ obtaining gradients through autodiff. Bounds are efficiently computed on GPUs. ## More Working Examples -We provide a wide range of examples of using `auto_LiRPA`: +We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`: -* [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks) -* [Basic **Certified Adversarial Defense** Training](doc/examples.md#basic-certified-adversarial-defense-training) -* [Large-scale Certified Defense Training on **ImageNet**](doc/examples.md#certified-adversarial-defense-on-downscaled-imagenet-and-tinyimagenet-with-loss-fusion) -* [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist) -* [Certifiably Robust Language Classifier using **Transformers**](doc/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm) -* [Certified Robustness against **Model Weight Perturbations**](doc/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense) +* [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/src/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks) +* [Basic **Certified Adversarial Defense** Training](doc/src/examples.md#basic-certified-adversarial-defense-training) +* [Large-scale Certified Defense Training on **ImageNet**](doc/src/examples.md#certified-adversarial-defense-on-downscaled-imagenet-and-tinyimagenet-with-loss-fusion) +* [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/src/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist) +* [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm) +* [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense) ## Full Documentations For more documentations, please refer to: * [Documentation homepage](https://auto-lirpa.readthedocs.io) - * [API documentation](https://auto-lirpa.readthedocs.io/en/latest/api.html) - -* [Adding custom operators](doc/custom_op.md) - -* [Guide](doc/paper.md) for reproducing [our NeurIPS 2020 paper](https://arxiv.org/abs/2002.12920) +* [Adding custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html) +* [Guide](https://auto-lirpa.readthedocs.io/en/latest/paper.html) for reproducing [our NeurIPS 2020 paper](https://arxiv.org/abs/2002.12920) ## Publications diff --git a/auto_LiRPA/__init__.py b/auto_LiRPA/__init__.py index af094d9..f0b5922 100644 --- a/auto_LiRPA/__init__.py +++ b/auto_LiRPA/__init__.py @@ -2,5 +2,6 @@ from .bounded_tensor import BoundedTensor, BoundedParameter from .perturbations import PerturbationLpNorm, PerturbationSynonym from .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput +from .bound_op_map import register_custom_op, unregister_custom_op __version__ = '0.2' \ No newline at end of file diff --git a/auto_LiRPA/bound_op_map.py b/auto_LiRPA/bound_op_map.py index b161d5f..087dab3 100644 --- a/auto_LiRPA/bound_op_map.py +++ b/auto_LiRPA/bound_op_map.py @@ -4,3 +4,21 @@ 'onnx::Gemm': BoundLinear, 'prim::Constant': BoundPrimConstant, } + +def register_custom_op(op_name: str, bound_obj: Bound) -> None: + """ Register a custom operator. + + Args: + op_name (str): Name of the custom operator + + bound_obj (Bound): The corresponding Bound class for the operator. + """ + bound_op_map[op_name] = bound_obj + +def unregister_custom_op(op_name: str) -> None: + """ Unregister a custom operator. + + Args: + op_name (str): Name of the custom operator + """ + bound_op_map.pop(op_name) \ No newline at end of file diff --git a/auto_LiRPA/operators/convolution.py b/auto_LiRPA/operators/convolution.py index 47ce9e9..6bc5d3f 100644 --- a/auto_LiRPA/operators/convolution.py +++ b/auto_LiRPA/operators/convolution.py @@ -490,8 +490,8 @@ def forward(self, x): def bound_backward(self, last_lA, last_uA, x): H, W = self.input_shape[-2], self.input_shape[-1] - lA = last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W) - uA = last_uA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W) + lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None + uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None return [(lA, uA)], 0, 0 diff --git a/doc/.gitignore b/doc/.gitignore index 91bd57a..61f3fbd 100644 --- a/doc/.gitignore +++ b/doc/.gitignore @@ -1,2 +1,5 @@ _build -sections \ No newline at end of file +sections +*.md +!src/*.md +!README.md \ No newline at end of file diff --git a/doc/README.md b/doc/README.md index e729255..48a25fd 100644 --- a/doc/README.md +++ b/doc/README.md @@ -1,5 +1,8 @@ + # Documentation +This directory contains source files for building our documentation. + ## Dependencies Install additional libraries for building documentations: @@ -17,4 +20,3 @@ make html ``` The documentation will be generated at `_build/html`. - diff --git a/doc/api.rst b/doc/api.rst index 5375cb5..f70bf4b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -18,6 +18,8 @@ API Usage .. autofunction:: auto_LiRPA.perturbations.Perturbation.concretize .. autofunction:: auto_LiRPA.perturbations.Perturbation.init +.. autofunction:: auto_LiRPA.bound_op_map.register_custom_op + Indices and tables ------------------- diff --git a/doc/conf.py b/doc/conf.py index 9658736..04976ac 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,10 +12,13 @@ # import os import subprocess -# import sys -# sys.path.insert(0, os.path.abspath('.')) +import inspect +import sys +from pygit2 import Repository +sys.path.insert(0, '..') +import auto_LiRPA -subprocess.run(['python', 'parse_readme.py']) +subprocess.run(['python', 'process.py']) # -- Project information ----------------------------------------------------- @@ -31,6 +34,7 @@ # ones. extensions = [ 'sphinx.ext.autodoc', + 'sphinx.ext.linkcode', 'm2r2', ] @@ -40,8 +44,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - +exclude_patterns = ['_build', 'src', 'Thumbs.db', '.DS_Store'] # -- Options for HTML output ------------------------------------------------- @@ -54,3 +57,27 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] + +repo = Repository('../') +branch = repo.head.shorthand + +# Resolve function for the linkcode extension. +def linkcode_resolve(domain, info): + def find_source(): + obj = auto_LiRPA + parts = info['fullname'].split('.') + if info['module'].endswith(f'.{parts[0]}'): + module = info['module'][:-len(parts[0])-1] + else: + module = info['module'] + obj = sys.modules[module] + for part in parts: + obj = getattr(obj, part) + fn = inspect.getsourcefile(obj) + source, lineno = inspect.getsourcelines(obj) + return fn, lineno, lineno + len(source) - 1 + + fn, lineno_start, lineno_end = find_source() + filename = f'{fn}#L{lineno_start}-L{lineno_end}' + + return f"https://github.com/KaidiXu/auto_LiRPA/blob/{branch}/doc/{filename}" diff --git a/doc/parse_readme.py b/doc/parse_readme.py deleted file mode 100644 index 081a9be..0000000 --- a/doc/parse_readme.py +++ /dev/null @@ -1,25 +0,0 @@ -import re -import os - -heading = '' -copied = {} -print('Parsing markdown sections from README:') -with open('../README.md') as file: - for line in file.readlines(): - if line.startswith('##'): - heading = line[2:].strip() - else: - if not heading in copied: - copied[heading] = '' - copied[heading] += line -if not os.path.exists('sections'): - os.makedirs('sections') -for key in copied: - if key == '': - continue - filename = re.sub(r"[?+\'\"]", '', key.lower()) - filename = re.sub(r" ", '-', filename) + '.md' - print(filename) - with open(os.path.join('sections', filename), 'w') as file: - file.write(f'## {key}\n') - file.write(copied[key]) \ No newline at end of file diff --git a/doc/process.py b/doc/process.py new file mode 100644 index 0000000..93ffc48 --- /dev/null +++ b/doc/process.py @@ -0,0 +1,65 @@ +""" Process source files before running Sphinx""" +import re +import os +import shutil +from pygit2 import Repository + +repo = 'https://github.com/KaidiXu/auto_LiRPA' +branch = os.environ.get('BRANCH', None) or Repository('.').head.shorthand +repo_file_path = os.path.join(repo, 'tree', branch) + +""" Parse README.md into sections which can be reused """ +heading = '' +copied = {} +print('Parsing markdown sections from README:') +with open('../README.md') as file: + for line in file.readlines(): + if line.startswith('##'): + heading = line[2:].strip() + else: + if not heading in copied: + copied[heading] = '' + copied[heading] += line +if not os.path.exists('sections'): + os.makedirs('sections') +for key in copied: + if key == '': + continue + filename = re.sub(r"[?+\'\"]", '', key.lower()) + filename = re.sub(r" ", '-', filename) + '.md' + print(filename) + with open(os.path.join('sections', filename), 'w') as file: + file.write(f'## {key}\n') + file.write(copied[key]) +print() + +""" Load source files from src/ and fix links to GitHub """ +for filename in os.listdir('src'): + print(f'Processing {filename}') + with open(os.path.join('src', filename)) as file: + source = file.read() + source_new = '' + ptr = 0 + # res = re.findall('\[.*\]\(.*\)', source) + for m in re.finditer('(\[.*\])(\(.*\))', source): + assert m.start() >= ptr + source_new += source[ptr:m.start()] + ptr = m.start() + source_new += m.group(1) + ptr += len(m.group(1)) + link_raw = m.group(2) + while len(link_raw) >= 2 and link_raw[-2] == ')': + link_raw = link_raw[:-1] + link = link_raw[1:-1] + if link.startswith('https://') or link.startswith('http://') or '.html#' in link: + print(f'Skip link {link}') + link_new = link + else: + link_new = os.path.join(repo_file_path, 'docs/src', link) + print(f'Fix link {link} -> {link_new}') + source_new += f'({link_new})' + ptr += len(link_raw) + source_new += source[ptr:] + with open(filename, 'w') as file: + file.write(source_new) + print() \ No newline at end of file diff --git a/doc/requirements.txt b/doc/requirements.txt index 3f99d51..671ba14 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,3 +1,4 @@ sphinx>=4.1.2 docutils>=0.16 -m2r2>=0.3.1 \ No newline at end of file +m2r2>=0.3.1 +pygit2>=1.7.2 \ No newline at end of file diff --git a/doc/bound_opts.md b/doc/src/bound_opts.md similarity index 100% rename from doc/bound_opts.md rename to doc/src/bound_opts.md diff --git a/doc/custom_op.md b/doc/src/custom_op.md similarity index 63% rename from doc/custom_op.md rename to doc/src/custom_op.md index ffd990f..853aa48 100644 --- a/doc/custom_op.md +++ b/doc/src/custom_op.md @@ -2,14 +2,22 @@ In this documentation, we introduce how users can define custom operators (such as other activations) that are not currently supported in auto_LiRPA, with bound propagation methods. -## Write an Operator +## Write a Custom Operator There are three steps to write an operator: -1. Define a `torch.autograd.Function` (or `Function` for short) class, wrap the computation of the operator into this `Function`, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html?highlight=symbolic#static-symbolic-method) on defining a `Function` with a symbolic method. Call this `Function` via `.apply()` when using this operator in the model. +1. Define a `torch.autograd.Function` (or `Function` for short) class, wrap the computation of the operator into this `Function`, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html?highlight=symbolic#static-symbolic-method) on defining a `Function` with a symbolic method. -2. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator. -3. Create a mapping from the operator name (defined in step 1) to the bound class (defined in step 2). Define a `dict` which each item is a mapping. Pass the `dict` to the `custom_ops` argument when calling `BoundedModule` (see the [documentation](api.html#auto_LiRPA.BoundedModule)). For example, if the operator name is `MyRelu`, and the bound class is `BoundMyRelu`, then add `"MyRelu": BoundMyRelu` to the `dict`. +2. Create a `torch.nn.Module` which uses the defined operator. Call the operator via +`.apply()` of `Function`. + +3. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator. + +4. [Register the custom operator](api.html#auto_LiRPA.bound_op_map.register_custom_op). + +## Example + +We provide an [code example](../../examples/vision/custom_op.py) of using a custom operator called "PlusConstant". ## Contributing to the Library diff --git a/doc/examples.md b/doc/src/examples.md similarity index 91% rename from doc/examples.md rename to doc/src/examples.md index d44e89c..0f3624c 100644 --- a/doc/examples.md +++ b/doc/src/examples.md @@ -1,6 +1,6 @@ # Examples -We provide many [examples](examples) of using our `auto_LiRPA` library, +We provide many [examples](../../examples) of using our `auto_LiRPA` library, including robustness verification and certified robust training for fairly complicated networks and specifications. Please first install required libraries to run the examples: @@ -13,7 +13,7 @@ pip install -r requirements.txt ## Basic Bound Computation and Robustness Verification of Neural Networks We provide a very simple tutorial for `auto_LiRPA` at -[examples/vision/simple_verification.py](examples/vision/simple_verification.py). +[examples/vision/simple_verification.py](../../examples/vision/simple_verification.py). This script is self-contained. It loads a simple CNN model and compute the guaranteed lower and upper bounds using LiRPA for each output neuron under a L infinity perturbation. @@ -34,7 +34,7 @@ can be obtained using α-CROWN within a few seconds. ## Basic Certified Adversarial Defense Training We provide a [simple example of certified -training](examples/vision/simple_training.py). By default it uses +training](../../examples/vision/simple_training.py). By default it uses [CROWN-IBP](https://arxiv.org/pdf/1906.06316.pdf) to train a certifiably robust model: @@ -59,17 +59,17 @@ python simple_training.py --model mlp_3layer --norm 0 --eps 1 ``` For CIFAR-10, we provided some sample models in `examples/vision/models`: -e.g., [cnn_7layer_bn](./examples/vision/models/feedforward.py), -[DenseNet](./examples/vision/models/densenet.py), -[ResNet18](./examples/vision/models/resnet18.py), -[ResNeXt](./examples/vision/models/resnext.py). For example, to train a ResNeXt model on CIFAR, +e.g., [cnn_7layer_bn](../../examples/vision/models/feedforward.py), +[DenseNet](../../examples/vision/models/densenet.py), +[ResNet18](../../examples/vision/models/resnet18.py), +[ResNeXt](../../examples/vision/models/resnext.py). For example, to train a ResNeXt model on CIFAR, use: ```bash python cifar_training.py --batch_size 256 --model ResNeXt_cifar ``` -See a list of supported models [here](./examples/vision/models/__init__.py). +See a list of supported models [here](../../examples/vision/models/__init__.py). This command uses multi-GPUs by default. You probably need to reduce batch size if you have only 1 GPU. The CIFAR training implementation includes **loss fusion**, a technique that can greatly reduce training time and memory usage of @@ -85,7 +85,7 @@ python cifar_training.py --verify --model cnn_7layer_bn --load saved_models/cnn ``` More example of CIFAR-10 training can be found -in [doc/paper.md](doc/paper.md). +in [doc/paper.md](paper.md). ## Certified Adversarial Defense on Downscaled ImageNet and TinyImageNet with Loss Fusion @@ -139,12 +139,12 @@ MODEL=saved_models/wide_resnet_imagenet64_1000 python imagenet_training.py --verify --model wide_resnet_imagenet64_1000class --load $MODEL --eps 0.003921568627451 ``` -See more details in [doc/paper.md](doc/paper.md) for these examples. +See more details in [paper.md](paper.md) for these examples. ## Certified Adversarial Defense Training for LSTM on MNIST -In [examples/sequence](examples/sequence), we have an example of training a +In [examples/sequence](../../examples/sequence), we have an example of training a certifiably robust LSTM on MNIST, where an input image is perturbed within an Lp-ball and sliced to several pieces each regarded as an input frame. To run the example: @@ -156,7 +156,7 @@ python train.py ## Certifiably Robust Language Classifier with Transformer and LSTM -In [examples/language](examples/language), we show that our framework can +In [examples/language](../../examples/language), we show that our framework can support perturbation specification of word substitution, beyond Lp-ball perturbation. We perform certified training for Transformer and LSTM on a sentiment classification task. diff --git a/doc/paper.md b/doc/src/paper.md similarity index 100% rename from doc/paper.md rename to doc/src/paper.md diff --git a/examples/.gitignore b/examples/.gitignore deleted file mode 100644 index 9e67bd4..0000000 --- a/examples/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -data - diff --git a/examples/vision/custom_op.py b/examples/vision/custom_op.py new file mode 100644 index 0000000..1d5d96e --- /dev/null +++ b/examples/vision/custom_op.py @@ -0,0 +1,125 @@ +""" A example for custom operators. + +In this example, we create a custom operator called "PlusConstant", which can +be written as "f(x) = x + c" for some constant "c" (an attribute of the operator). +""" +import torch +import torch.nn as nn +import torchvision +from auto_LiRPA import BoundedModule, BoundedTensor, register_custom_op +from auto_LiRPA.operators import Bound +from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import Flatten + +""" Step 1: Define a `torch.autograd.Function` class to declare and implement the +computation of the operator. """ +class PlusConstantOp(torch.autograd.Function): + @staticmethod + def symbolic(g, x, const): + """ In this function, define the arguments and attributes of the operator. + "custom::PlusConstant" is the name of the new operator, "x" is an argument + of the operator, "const_i" is an attribute which stands for "c" in the operator. + There can be multiple arguments and attributes. For attribute naming, + use a suffix such as "_i" to specify the data type, where "_i" stands for + integer, "_t" stands for tensor, "_f" stands for float, etc. """ + return g.op('custom::PlusConstant', x, const_i=const) + + @staticmethod + def forward(ctx, x, const): + """ In this function, implement the computation for the operator, i.e., + f(x) = x + c in this case. """ + return x + const + +""" Step 2: Define a `torch.nn.Module` class to declare a module using the defined +custom operator. """ +class PlusConstant(nn.Module): + def __init__(self, const=1): + super().__init__() + self.const = const + + def forward(self, x): + """ Use `PlusConstantOp.apply` to call the defined custom operator. """ + return PlusConstantOp.apply(x, self.const) + +""" Step 3: Implement a Bound class to support bound computation for the new operator. """ +class BoundPlusConstant(Bound): + def __init__(self, input_name, name, ori_name, attr, inputs, output_index, options, device): + """ `const` is an attribute and can be obtained from the dict `attr` """ + super().__init__(input_name, name, ori_name, attr, inputs, output_index, options, device) + self.const = attr['const'] + + def forward(self, x): + return x + self.const + + def bound_backward(self, last_lA, last_uA, x): + """ Backward mode bound propagation """ + print('Calling bound_backward for custom::PlusConstant') + def _bound_oneside(last_A): + # If last_lA or last_uA is None, it means lower or upper bound + # is not required, so we simply return None. + if last_A is None: + return None, 0 + # The function f(x) = x + c is a linear function with coefficient 1. + # Then A · f(x) = A · (x + c) = A · x + A · c. + # Thus the new A matrix is the same as the last A matrix: + A = last_A + # For bias, compute A · c and reduce the dimensions by sum: + bias = last_A.sum(dim=list(range(2, last_A.ndim))) * self.const + return A, bias + lA, lbias = _bound_oneside(last_lA) + uA, ubias = _bound_oneside(last_uA) + return [(lA, uA)], lbias, ubias + + def interval_propagate(self, *v): + """ IBP computation """ + print('Calling interval_propagate for custom::PlusConstant') + # Interval bound of the input + h_L, h_U = v[0] + # Since this function is monotonic, we can get the lower bound and upper bound + # by applying the function on h_L and h_U respectively. + lower = h_L + self.const + upper = h_U + self.const + return lower, upper + +""" Step 4: Register the custom operator """ +register_custom_op("custom::PlusConstant", BoundPlusConstant) + +# Use the `PlusConstant` module in model definition +model = nn.Sequential( + Flatten(), + nn.Linear(28 * 28, 256), + PlusConstant(const=1), + nn.Linear(256, 10), +) +print("Model:", model) + +test_data = torchvision.datasets.MNIST("./data", train=False, download=True, transform=torchvision.transforms.ToTensor()) +N = 1 +n_classes = 10 +image = test_data.data[:N].view(N,1,28,28) +true_label = test_data.targets[:N] +image = image.to(torch.float32) / 255.0 +if torch.cuda.is_available(): + image = image.cuda() + model = model.cuda() + +lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device) + +eps = 0.3 +norm = float("inf") +ptb = PerturbationLpNorm(norm = norm, eps = eps) +image = BoundedTensor(image, ptb) +pred = lirpa_model(image) +label = torch.argmax(pred, dim=1).cpu().detach().numpy() + +for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']: + print("Bounding method:", method) + lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0]) + for i in range(N): + print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i])) + for j in range(n_classes): + indicator = '(ground-truth)' if j == true_label[i] else '' + print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format( + j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator)) + print() + diff --git a/examples/vision/data/ImageNet64/imagenet_data_loader.py b/examples/vision/data/ImageNet64/imagenet_data_loader.py new file mode 100644 index 0000000..2883da0 --- /dev/null +++ b/examples/vision/data/ImageNet64/imagenet_data_loader.py @@ -0,0 +1,43 @@ +import os + +import numpy as np +from PIL import Image + + +class DatasetDownsampledImageNet(): + def __init__(self): + # self.data_path = data_path + os.mkdir('train') + os.mkdir('test') + for i in range(1000): + os.mkdir('train/' + str(i)) + os.mkdir('test/' + str(i)) + print(i) + self.load_data('raw_data/Imagenet64_train_npz', count=0, fname='train/') + self.load_data('raw_data/Imagenet64_val_npz', count=1e8, fname='test/') + + def load_data(self, data_path, img_size=64, count=0., fname=''): + files = os.listdir(data_path) + img_size2 = img_size * img_size + + # count = 0 # 1e8 # test data start with 1 + for file in files: + f = np.load(data_path + '/' + file) + x = np.array(f['data']) + y = np.array(f['labels']) - 1 + x = np.dstack((x[:, :img_size2], x[:, img_size2:2 * img_size2], x[:, 2 * img_size2:])) + x = x.reshape((x.shape[0], img_size, img_size, 3)) + + for i, img in enumerate(x): + img = Image.fromarray(img.reshape(img_size, img_size, 3)) + name = str(int(count)).zfill(9) + label = str(y[i]) + print(count, fname + label + '/' + name + '_label_' + label.zfill(4) + '.png') + # img.show() + img.save(fname + label + '/' + name + '_label_' + label.zfill(4) + '.png') + + count += 1 + + +if __name__ == "__main__": + DatasetDownsampledImageNet() diff --git a/setup.py b/setup.py index 72614ef..0807d51 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ from setuptools import setup, find_packages +import sys """Check PyTorch version""" pytorch_version_l = "1.8.0" @@ -22,6 +23,13 @@ if '__version__' in line: version = eval(line.strip().split()[-1]) +assert sys.version_info.major == 3, 'Python 3 is required' +if sys.version_info.minor < 8: + # numpy 1.22 requires Python 3.8+ + numpy_requirement = 'numpy>=1.16,<=1.21' +else: + numpy_requirement = 'numpy>=1.16' + print(f'Installing auto_LiRPA {version}') setup( name='auto_LiRPA', @@ -34,7 +42,7 @@ install_requires=[ f'torch>={pytorch_version_l},<{pytorch_version_u}', 'torchvision>=0.9,<0.10', - 'numpy>=1.16', + numpy_requirement, 'packaging>=20.0', 'pytest>=5.0', 'appdirs>=1.4',