Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug checkRobust #232

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion code/nnv/engine/nn/NN.m
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
nr = length(outputSet);
R = Star;
for s=1:nr
R = outputSet(s).toStar;
R(s) = outputSet(s).toStar;
end
else
R = outputSet;
Expand Down
3 changes: 3 additions & 0 deletions code/nnv/engine/nn/funcs/PosLin.m
Original file line number Diff line number Diff line change
Expand Up @@ -2073,6 +2073,9 @@
R = PosLin.reach_star_approx2(I, option, dis_opt, lp_solver);
elseif strcmp(method, 'relax-star-range')
R = PosLin.reach_relaxed_star_range(I, relaxFactor, option, dis_opt, lp_solver);
if contains(method, 'reduceMem')
R = R.reduceConstraints;
end
elseif strcmp(method, 'relax-star-bound')
R = PosLin.reach_relaxed_star_bound(I, relaxFactor, option, dis_opt, lp_solver);
elseif strcmp(method, 'relax-star-area')
Expand Down
8 changes: 5 additions & 3 deletions code/nnv/engine/nn/layers/ConcatenationLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@

% Get max size of V
vSize = size(inputs{1}.V);
indexMax = 1;
for i = 2:length(inputs)
if numel(size(inputs{i}.V)) > numel(vSize)
vSize = inputs{i}.V;
if numel(inputs{i}.V) > prod(vSize)
vSize = size(inputs{i}.V);
indexMax = i;
end
end

Expand All @@ -125,7 +127,7 @@
end

% Create output set
outputs = ImageStar(new_V, inputs{1}.C, inputs{1}.d, inputs{1}.pred_lb, inputs{1}.pred_ub);
outputs = ImageStar(new_V, inputs{indexMax}.C, inputs{indexMax}.d, inputs{indexMax}.pred_lb, inputs{indexMax}.pred_ub);
end

% handle multiple inputs
Expand Down
34 changes: 33 additions & 1 deletion code/nnv/engine/nn/layers/ReluLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,45 @@
h = in_image.height;
w = in_image.width;
c = in_image.numChannel;

reduceMem = 0;
if contains(method, "reduceMem")
method = split(method, '-');
method = strjoin(method(1:end-1),'-');
reduceMem = 1;
end
Y = PosLin.reach(in_image.toStar, method, [], relaxFactor, dis_opt, lp_solver); % reachable set computation with ReLU
n = length(Y);
images(n) = ImageStar;
% transform back to ImageStar
for i=1:n
images(i) = Y(i).toImageStar(h,w,c);
if reduceMem
[l,u] = images(i).estimateRanges;
% Get dimensions
h = images(i).height;
w = images(i).width;
c = images(i).numChannel;
% create ImageStar variables
center = cast(0.5*(u+l), 'like', in_image.V);
v = 0.5*(u-l);
idxRegion = find((l-u));
n = length(idxRegion);
V = zeros(h, w, c, n, 'like', center);
bCount = 1;
for i1 = 1:size(v,1)
for i2 = 1:size(v,2)
basisValue = v(i1,i2);
if basisValue
V(i1,i2,:,bCount) = v(i1,i2);
bCount = bCount + 1;
end
end
end
C = zeros(1, n, 'like', V);
d = zeros(1, 1, 'like', V);
% Construct ImageStar
images(i) = ImageStar(cat(4,center,V), C, d, -1*ones(n,1), ones(n,1), l, u);
end
end
else % star
images = PosLin.reach(in_image, method, [], relaxFactor, dis_opt, lp_solver); % reachable set computation with ReLU
Expand Down
54 changes: 22 additions & 32 deletions code/nnv/engine/set/ImageStar.m
Original file line number Diff line number Diff line change
Expand Up @@ -511,17 +511,19 @@
end

function image = recurrentMap(obj, h_t_1, inputWeight, recurrentWeight, bias)
n = obj.numPred;
N = obj.height*obj.width*obj.numChannel;
for i=1:n+1
I = in_image.V(:,:,:,i);
%I = reshape(I,N,1); % flatten input
if i==1
V(1, 1,:,i) = double(inputWeight)*I + double(recurrentWeight)*h_t_1 + double(bias);
else
V(1, 1,:,i) = double(inputWeight)*I;
end
end
% Whoever implemented this, did not use it as it would return an error
% n = obj.numPred;
% N = obj.height*obj.width*obj.numChannel;
% for i=1:n+1
% I = in_image.V(:,:,:,i);
% %I = reshape(I,N,1); % flatten input
% if i==1
% V(1, 1,:,i) = double(inputWeight)*I + double(recurrentWeight)*h_t_1 + double(bias);
% else
% V(1, 1,:,i) = double(inputWeight)*I;
% end
% end
error("Not supported yet.");
end

function image = HadamardProduct(obj, I)
Expand Down Expand Up @@ -847,30 +849,18 @@
end

if isempty(obj.C) || isempty(obj.d)
error('The imagestar is empty');
warning('The imagestar is empty');
% no ranges, so bounds are the same as image
obj.im_lb = obj.V;
obj.im_ub = obj.V;
end

if isempty(obj.im_lb) || isempty(obj.im_ub)

image_lb = zeros(obj.height, obj.width, obj.numChannel);
image_ub = zeros(obj.height, obj.width, obj.numChannel);
reverseStr = '';
N = obj.height*obj.width*obj.numChannel;
if strcmp(dis_opt, 'display')
fprintf('\nEstimating Range: ');
end
for i=1:obj.height
for j=1:obj.width
for k=1:obj.numChannel
[image_lb(i, j, k), image_ub(i, j, k)] = obj.estimateRange(i,j,k);
if strcmp(dis_opt, 'display')
msg = sprintf('%d/%d', i*j*k, N);
fprintf([reverseStr, msg]);
reverseStr = repmat(sprintf('\b'), 1, length(msg));
end
end
end
end
x1 = obj.V(:,:,:,1) + tensorprod(obj.V(:,:,:,2:end), obj.pred_lb, 4, 1);
x2 = obj.V(:,:,:,1) + tensorprod(obj.V(:,:,:,2:end), obj.pred_ub, 4, 1);
image_lb = min(x1,x2);
image_ub = max(x1,x2);

obj.im_lb = image_lb;
obj.im_ub = image_ub;
Expand Down Expand Up @@ -1523,7 +1513,7 @@ function addInputSize(obj, name, inputSize)
new_d = [in_IS.d; new_d];

end

end


Expand Down
24 changes: 15 additions & 9 deletions code/nnv/examples/NN/FairNNV/adult_exact_verify.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
clear; clc;
modelDir = './adult_onnx'; % Directory containing ONNX models
onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files
onnxFiles = onnxFiles(1); % simplify for debugging

load("adult_data.mat", 'X', 'y'); % Load data once

Expand Down Expand Up @@ -58,12 +59,12 @@
% First, we define the reachability options
reachOptions = struct; % initialize
reachOptions.reachMethod = 'exact-star';
reachOptions.relaxFactor = 0.5;

nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilons value here
epsilon = [0.0,0.001,0.01];
% epsilon = [0.001,0.01];
epsilon = 0.01;

% Set up results
nE = 3;
Expand All @@ -87,19 +88,23 @@
start(verificationTimer); % Start the timer


for i=1:numObs
% for i=1:numObs
for i=57
idx = rand_indices(i);
IS = perturbationIF(X_test_loaded(:, idx), epsilon(e), min_values, max_values);


unsafeRegion = net.robustness_set(y_test_loaded(idx), 'min');

t = tic; % Start timing the verification for each sample

temp = net.verify_robustness(IS, reachOptions, y_test_loaded(idx));
temp = net.verify_robustness(IS, reachOptions, unsafeRegion);
met(i,e) = 'exact';
res(i,e) = temp; % robust result
% end

res(i,e) = temp; % robust result
time(i,e) = toc(t); % store computation time

if ~(temp == 1)
counterExs = getCounterRegion(IS,unsafeRegion,net.reachSet{end});
end

% Check for timeout flag
if evalin('base', 'timeoutOccurred')
Expand Down Expand Up @@ -164,6 +169,7 @@
% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
IS = Star(single(lb), single(ub)); % default: single (assume onnx input models)

end

11 changes: 6 additions & 5 deletions code/nnv/examples/NN/FairNNV/adult_verifiy.m
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilon values here
epsilon = [0.0,0.001,0.01];
epsilon = [0.001,0.01];


% Set up results
Expand All @@ -90,7 +90,7 @@
start(verificationTimer); % Start the timer

% Iterate through observations
for i=1:numObs
for i=38
idx = rand_indices(i);
[IS, xRand] = perturbationIF(X_test_loaded(:, idx), epsilon(e), nR, min_values, max_values);

Expand All @@ -109,6 +109,7 @@
time(i,e) = toc(t);
met(i,e) = "counterexample";
skipTryCatch = true; % Set the flag to skip try-catch block
disp('Counter example found');
continue;
end
end
Expand Down Expand Up @@ -198,7 +199,7 @@
% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
IS = Star(single(lb), single(ub)); % default: single (assume onnx input models)

% Create random samples from initial set
% Adjusted reshaping according to specific needs
Expand All @@ -208,7 +209,7 @@
xRand = xB.sample(nR);
xRand = reshape(xRand,[13,nR]);
xRand(:,nR+1) = x; % add original image
xRand(:,nR+2) = IS.im_lb; % add lower bound image
xRand(:,nR+3) = IS.im_ub; % add upper bound image
xRand(:,nR+2) = xB.lb; % add lower bound image
xRand(:,nR+3) = xB.ub; % add upper bound image
end

38 changes: 38 additions & 0 deletions code/nnv/examples/NN/FairNNV/getCounterRegion.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
function counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet)
% counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet)
% NOTE: This is only to be used with exact-star method
% unsafeRegion = HalfSpace (unsafe/undesired region)
% inputSet = ImageStar/Star
% reachSet = Star
%
% check the "safety" of the reachSet
% Then, generate counterexamples

% Initialize variables
counterExamples = [];

% Get halfspace variables
G = unsafeRegion.G;
g = unsafeRegion.g;

% Check for valid inputs
if ~isa(inputSet, "Star")
error("Must be a Star");
end
if ~isa(reachSet, "Star")
error("Must be Star or ImageStar");
end

% Begin counterexample computation
n = length(reachSet); % number of stars in the output set
V = inputSet.V;
for i=1:n
% Check for safety, if unsafe, add to counter
if ~isempty(reachSet(i).intersectHalfSpace(G, g))
counterExamples = [counterExamples Star(V, reachSet(i).C, reachSet(i).d,...
reachSet(i).predicate_lb, reachSet(i).predicate_ub)];
end
end

end

This file was deleted.

Loading