-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #238 from mldiego/master
work in progress, fix github issues
- Loading branch information
Showing
357 changed files
with
876 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
function I = remove_voxels(vol, voxels, noise_disturbance) | ||
% noise_disturnamce can be kept fixed here, more interesting on number | ||
% of voxels changed | ||
|
||
% Return a VolumeStar of a brightening attack on a few pixels | ||
|
||
% Initialize vars | ||
vol = single(vol); | ||
at_vol = vol; | ||
|
||
% we can find the edge of the shape | ||
shape = edge3(vol,'approxcanny',0.6); % this should be okay for this data, but let's test it | ||
|
||
% select a random pixel | ||
idxs = intersect(find(shape),find(vol)); | ||
voxels = min(voxels,length(idxs)); | ||
|
||
% For now, we can select the first ones | ||
at_vol(idxs(1:voxels)) = 0; | ||
|
||
% Define input set as VolumeStar | ||
dif_vol = -vol + at_vol; | ||
noise = dif_vol; | ||
V(:,:,:,:,1) = vol; % center of set | ||
V(:,:,:,:,2) = noise; % basis vectors | ||
C = [1; -1]; % constraints | ||
d = [1; noise_disturbance-1]; % constraints | ||
I = VolumeStar(V, C, d, 1-noise_disturbance, 1); % input set | ||
|
||
|
||
end |
2 changes: 0 additions & 2 deletions
2
code/nnv/examples/Submission/WiP_3d/functions/summarize_results.m
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
|
||
# NAV Benchmark | ||
|
||
## Property: | ||
The control goal is to navigate a robot to a goal region while avoiding an obstacle. | ||
Time horizon: `t = 6s`. Control period: `0.2s`. | ||
|
||
Initial states: | ||
|
||
x1 = [2.9, 3.1] | ||
x2 = [2.9, 3.1] | ||
x3 = [0, 0] | ||
x4 = [0, 0] | ||
|
||
Dynamic system: [dynamics.m](./dynamics.m) | ||
|
||
Goal region ( t=6 ): | ||
|
||
x1 = [-0.5, 0.5] | ||
x2 = [-0.5, 0.5] | ||
x3 = [-Inf, Inf] | ||
x4 = [-Inf, Inf] | ||
|
||
Obstacle ( always ): | ||
|
||
x1 = [1, 2] | ||
x2 = [1, 2] | ||
x3 = [-Inf, Inf] | ||
x4 = [-Inf, Inf] | ||
|
||
## Networks: | ||
|
||
We provide two networks: | ||
- The first network is trained with standard (point-based) reinforcement learning: `nn-nav-point.onnx` | ||
- The second network is trained set-based to improve its verifiable robustness by integrating reachability analysis into the training process: `nn-nav-set.onnx` | ||
|
||
Reference set-based training: https://arxiv.org/abs/2401.14961 | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function dx = dynamics(x,u) | ||
|
||
dx = [ | ||
x(3)*cos(x(4)); | ||
x(3)*sin(x(4)); | ||
u(1); | ||
u(2) | ||
]; | ||
|
||
end | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file added
BIN
+10.4 KB
code/nnv/examples/Submission/WiP_3d/other/NAV/networks/nn-nav-point.onnx
Binary file not shown.
Binary file added
BIN
+10.4 KB
code/nnv/examples/Submission/WiP_3d/other/NAV/networks/nn-nav-set.onnx
Binary file not shown.
135 changes: 135 additions & 0 deletions
135
code/nnv/examples/Submission/WiP_3d/other/NAV/reach_point.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
function rT = reach_point() | ||
|
||
%% Reachability analysis of NAV Benchmark | ||
|
||
%% Load Components | ||
|
||
% Load the controller | ||
%netonnx = importNetworkFromONNX('networks/nn-nav-point.onnx', "InputDataFormats", "BC"); | ||
netonnx = importONNXNetwork('networks/nn-nav-point.onnx', "InputDataFormats", "BC"); | ||
|
||
% Load plant | ||
reachStep = 0.02; | ||
controlPeriod = 0.2; | ||
% plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
% plant.set_tensorOrder(2); | ||
% plant.set_taylorTerms(3); | ||
% plant.set_zonotopeOrder(100); | ||
% plant.set_intermediateOrder(50); | ||
|
||
%% Reachability analysis | ||
|
||
% Initial set | ||
lb = [2.9; 2.9; 0; 0]; | ||
ub = [3.1; 3.1; 0; 0]; | ||
init_set = Box(lb,ub); | ||
init = init_set.partition([1 2],[50 50]); | ||
|
||
% Reachability options | ||
num_steps = 20; | ||
reachOptions.reachMethod = 'approx-star'; | ||
|
||
N = length(init); | ||
disp("Verifying "+string(N)+" samples...") | ||
|
||
mkdir('tmp'); | ||
parpool("Processes"); % initialize parallel process | ||
|
||
% Execute reachabilty analysis | ||
t = tic; | ||
parfor j = 1:length(init) | ||
% Get NNV network | ||
net = matlab2nnv(netonnx); | ||
% Create plant | ||
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
plant.set_tensorOrder(2); | ||
plant.set_taylorTerms(3); | ||
plant.set_zonotopeOrder(100); | ||
plant.set_intermediateOrder(50); | ||
% Get initial conditions | ||
init_set = init(j).toStar; | ||
%reachSub = init_set; | ||
for i = 1:num_steps | ||
% Compute controller output set | ||
input_set = net.reach(init_set,reachOptions); | ||
|
||
% Compute plant reachable set | ||
init_set = plantReach(plant, init_set, input_set,'lin'); | ||
end | ||
toc(t); | ||
parsave("tmp/reachSet"+string(j)+".mat",plant); | ||
end | ||
rT = toc(t); % get reach time | ||
disp("Finished reachability...") | ||
|
||
% Shut Down Current Parallel Pool | ||
poolobj = gcp('nocreate'); | ||
delete(poolobj); | ||
|
||
% Save results | ||
if is_codeocean | ||
save('/results/logs/nav_point.mat', 'rT','-v7.3'); | ||
else | ||
save('nav_point.mat', 'rT','-v7.3'); | ||
end | ||
|
||
|
||
%% Visualize results | ||
setFiles = dir('tmp/*.mat'); | ||
|
||
t = tic; | ||
|
||
f = figure; | ||
rectangle('Position',[-0.5,-0.5,1,1],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth', 0.1); % goal region | ||
hold on; | ||
rectangle('Position',[1,1,1,1],'FaceColor',[0.7 0 0 0.8], 'EdgeColor','r', 'LineWidth', 0.1); % obstacle | ||
grid; | ||
for K = 1 : length(setFiles) | ||
if ~mod(K,50) | ||
disp("Plotting partition "+string(K)+" ..."); | ||
toc(t) | ||
pause(0.01); % to ensure it prints | ||
end | ||
res = load("tmp/"+setFiles(K).name); | ||
plant = res.plant; | ||
for k=1:length(plant.cora_set) | ||
plot(plant.cora_set{k}, [1,2], 'b', 'Unify', true); | ||
end | ||
end | ||
hold on; | ||
xlabel('x1'); | ||
ylabel('x2'); | ||
|
||
disp("Finished plotting all reach sets"); | ||
|
||
%% Save figure | ||
if is_codeocean | ||
saveas(f,'/results/logs/nav_point.png'); | ||
% exportgraphics(f,'/results/logs/nav-set.pdf', 'ContentType', 'vector'); | ||
else | ||
saveas(f,'nav_point_21.png'); | ||
% exportgraphics(f,'nav-set.pdf','ContentType', 'vector'); | ||
end | ||
|
||
% Save results | ||
if is_codeocean | ||
save('/results/logs/nav_point.mat','rT','-v7.3'); | ||
else | ||
save('nav_point.mat', 'rT','-v7.3'); | ||
end | ||
|
||
end | ||
|
||
%% Helper function | ||
function init_set = plantReach(plant,init_set,input_set,algoC) | ||
nS = length(init_set); % based on approx-star, number of sets should be equal | ||
ss = []; | ||
for k=1:nS | ||
ss =[ss plant.stepReachStar(init_set(k), input_set(k),algoC)]; | ||
end | ||
init_set = ss; | ||
end | ||
|
||
function parsave(fname, plant) % trick to save while on parpool | ||
save(fname, 'plant') | ||
end |
Oops, something went wrong.