-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaquila.py
105 lines (71 loc) · 2.3 KB
/
aquila.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# _ _
# __ _ __ _ _ _(_) | __ _
# / _` |/ _` | | | | | |/ _` |
# | (_| | (_| | |_| | | | (_| |
# \__,_|\__, |\__,_|_|_|\__,_|
# |_|
#
import time
import matplotlib.pyplot as plt
import numpy as np
from tkinter import filedialog
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from PIL import Image
from inundatio import get_houses, pad_bbox
from perses.dnet import Net as Perses
from unet import UNet
from src.cli import opt
from src.data import combine_maps, crop
from src.vis import add_padding, draw_bbox
# Checkpoint paths for the two networks, taken from the command-line options.
THEIA_MODEL = opt.theia
PERSES_MODEL = opt.perses

# Image paths: use the command-line value when one was given, otherwise ask
# the user through a file dialog.
PRE = opt.pre or filedialog.askopenfilename(title='Choose pre-disaster image')
POST = opt.post or filedialog.askopenfilename(
    title='Choose post-disaster image')

# Validate with explicit raises rather than `assert`: assertions are removed
# when Python runs with -O, and a cancelled dialog leaves the path empty.
if not (PRE and POST):
    raise ValueError(
        'Both a pre-disaster and a post-disaster image are required')

before = Image.open(PRE)
after = Image.open(POST)
if before.size != after.size:
    raise ValueError(
        'Pre- and post-disaster images differ in size: %s vs %s'
        % (before.size, after.size))

# Image width in pixels; passed as `clip_max` when bounding boxes are padded.
# NOTE(review): only the width is used, so this presumably assumes square
# images -- confirm against add_padding.
SIZE = before.size[0]

# Run on the GPU only when one is available and the user did not opt out.
use_cuda = torch.cuda.is_available() and not opt.no_cuda
dev = torch.device('cuda' if use_cuda else 'cpu')
print('Using device "%s" for calculation' % dev)
# Both checkpoints are loaded with map_location=dev so that weights saved on
# a GPU can still be deserialised when this script runs on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.

# theia: UNet with 3 input channels and 1 output channel, used in eval mode.
theia = UNet(in_channels=3, out_channels=1, padding=True)
theia.load_state_dict(torch.load(THEIA_MODEL, map_location=dev))
theia.eval()
theia = theia.to(dev)

# perses: network applied to the cropped bounding-box images, in eval mode.
perses = Perses()
perses.load_state_dict(torch.load(PERSES_MODEL, map_location=dev))
perses.eval()
perses = perses.to(dev)
# Start of the timed section (reported later when opt.time is set).
t0 = time.perf_counter()

# crop() turns the pre-disaster image into the pieces that theia processes
# one at a time (presumably tensor tiles -- confirm in src.data).
before = crop(before)
seg_maps = []
# Inference only: disabling autograd avoids building a graph for every tile.
with torch.no_grad():
    for tile in before:
        logits = theia(tile.unsqueeze(0).to(dev))
        mask = torch.sigmoid(logits).squeeze().detach().to('cpu')
        # Binarise: values below the threshold become 0, the rest become 1.
        mask[mask < opt.threshhold] = 0
        mask[mask > 0] = 1
        seg_maps.append(mask)

# Merge the per-piece masks into one map and extract padded bounding boxes.
seg_map = combine_maps(seg_maps).squeeze().detach().numpy()
coords = get_houses(seg_map)
coords = [add_padding(box, clip_max=SIZE) for box in coords]
# Show the post-disaster image and annotate every detected box on top of it.
fig, ax = plt.subplots(1)
ax.imshow(after)
# Inference only: no gradients are needed for the per-box predictions.
with torch.no_grad():
    for bbox in coords:
        # bbox is indexed as ((left, upper), (right, lower)) for PIL's crop.
        first_corner, second_corner = bbox[0], bbox[1]
        bbox_image = after.crop((first_corner[0], first_corner[1],
                                 second_corner[0], second_corner[1]))
        # perses is fed a fixed 75x75 tensor for each box.
        x = TF.to_tensor(TF.resize(bbox_image, (75, 75)))
        damage = perses(x.to(dev).unsqueeze(0))
        damage = torch.sigmoid(damage).detach().squeeze().to('cpu')
        draw_bbox(ax, bbox, damage)

t1 = time.perf_counter()
if opt.time:
    print('Took %.3f s to evaluate' % (t1 - t0))

# Write the figure to the requested path, or display it interactively.
if opt.save:
    plt.savefig(opt.save)
else:
    plt.show()