-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
248 lines (199 loc) · 9.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import cv2
import numpy as np
import argparse
from PyQt5.QtCore import pyqtSignal, QThread, pyqtSlot, Qt, QPoint
from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPen
from PyQt5.QtWidgets import QLabel, QMainWindow, QApplication, QPushButton, QSlider
from PyQt5.uic import loadUi
from time import perf_counter as pc, sleep
import math
from thread_image import ImageThread
# from thread_blur import ImageProcessingThread
from thread_depthanything import get_depth
from click_label import ClickLabel
class MainWindow(QMainWindow):
    """Interactive refocusing window.

    Loads an RGB image and a depth map for it (via ``get_depth``), shows the
    image on a clickable label, and re-renders it whenever the user clicks a
    point or moves the F-number slider. The depth map is split into equal-width
    bins; the bin under the clicked point is kept sharp and every other bin is
    blurred with a radius derived from a circle-of-confusion estimate.
    """

    # Emitted with every RGB frame passed to display_image().
    image_ready = pyqtSignal(np.ndarray)

    def __init__(self, path, f_num, kernel_t):
        """Build the UI and load the image plus its depth map.

        Args:
            path: Path of the image to refocus.
            f_num: Initial F-number used for the blur.
            kernel_t: Kernel type tag; stored as ``self.kernel_type`` but not
                read anywhere in this class.

        Raises:
            FileNotFoundError: If OpenCV cannot read an image from *path*.
        """
        super().__init__()
        # Set up UI from the Qt Designer file
        self.ui = loadUi('new.ui', self)
        self.label = self.findChild(ClickLabel, "label")
        self.slider_tgt = self.findChild(QSlider, "slider_tgt")
        self.button_save = self.findChild(QPushButton, "button_save")

        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {path}")
        self.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        H, W = self.image.shape[:2]
        # Original size, so save_image() can write back at full resolution
        self.orig_H = H
        self.orig_W = W

        # Resize so the longest side is exactly fixed_res pixels
        fixed_res = 756
        scale = fixed_res / max(H, W)
        H, W = int(H * scale), int(W * scale)
        self.image = cv2.resize(self.image, (W, H))
        self.image_depth = get_depth(path)
        self.image_depth = cv2.resize(self.image_depth, (W, H))
        self.setFixedWidth(W)
        self.setFixedHeight(H)

        # Create attributes
        self.f_num = f_num
        self.offset_point = QPoint(0, 0)
        self.kernel_type = kernel_t
        self.display_image(self.image)  # also sets self.current_image

        # Discrete f-stops the slider steps through (slider position = list index)
        self.slider_values = [1.8, 2.0, 2.2, 2.5, 2.8, 3.2, 4.0, 4.5, 5.0, 5.6, 6.3, 7.1, 8.0, 9.0, 10.0, 11.0, 13.0, 16.0]

        # Label showing the currently selected F-number
        self.f_number_label = QLabel(self)
        self.f_number_label.setAlignment(Qt.AlignCenter)
        self.f_number_label.setStyleSheet("QLabel { font-size: 11pt; margin-bottom: 5px; }")

        self.setup_slider(self.slider_tgt)
        self.label.clicked.connect(self.handle_label_click)
        self.button_save.clicked.connect(self.save_image)

    def setup_slider(self, slider):
        """Configure *slider* to step through ``self.slider_values``.

        The slider's integer position is an index into ``slider_values``. The
        initial position is the stop closest to ``self.f_num``. For an exact
        match, that is the matching stop.
        """
        slider.setMinimum(0)
        slider.setMaximum(len(self.slider_values) - 1)
        slider.setTickPosition(QSlider.TicksBelow)
        slider.setTickInterval(1)
        slider.setMinimumHeight(45)  # Make room for labels
        # Snap to the nearest stop rather than falling back to index 0 when
        # f_num (e.g. from the command line) is not one of the listed values
        initial_index = min(
            range(len(self.slider_values)),
            key=lambda i: abs(self.slider_values[i] - self.f_num),
        )
        slider.setValue(initial_index)
        # Connect after setValue so positioning the slider does not re-render
        slider.valueChanged.connect(self.update_slider_value)

    def update_slider_value(self, value):
        """Slot for ``valueChanged``: map the slider index to an f-stop and re-render."""
        mapped_value = self.slider_values[value]
        self.f_number_label.setText(f"F-number: f/{mapped_value}")
        self.f_num = mapped_value
        self.update_image()

    def paintEvent(self, event):
        """Draw ``f/<value>`` captions for every third slider stop above the save button."""
        super().paintEvent(event)
        if not hasattr(self, 'slider_tgt'):
            return
        painter = QPainter(self)
        painter.setPen(Qt.black)
        font_metrics = painter.fontMetrics()
        # Usable slider track, leaving ~10 px of margin on each side
        slider_rect = self.slider_tgt.geometry()
        slider_width = slider_rect.width() - 20
        x_offset = slider_rect.x() + 10
        # Captions are drawn 5 px above the save button to avoid overlapping it
        y_pos = self.button_save.geometry().top() - 5
        last_index = len(self.slider_values) - 1
        for i in range(0, len(self.slider_values), 3):  # every third stop avoids crowding
            text = f"f/{self.slider_values[i]}"
            # Centre the caption horizontally on the i-th tick mark
            x_pos = x_offset + (i * slider_width) / last_index
            text_width = font_metrics.horizontalAdvance(text)
            painter.drawText(int(x_pos - text_width / 2), int(y_pos), text)
        painter.end()

    @pyqtSlot(QPoint)
    def handle_label_click(self, pos):
        """Record the clicked point as the focus point and re-render."""
        # NOTE(review): the original author flagged this fixed +28 px vertical
        # shift as a bug; presumably it compensates for an offset between the
        # label and the pixmap - confirm against ClickLabel before changing.
        self.offset_point = pos + QPoint(0, 28)
        self.update_image()

    def save_image(self):
        """Write the currently displayed image, resized to the original resolution.

        The filename embeds the F-number and the focus point as a tuple repr,
        e.g. ``out_4.0_(120, 340).png``, in the current working directory.
        """
        save_path = f"out_{self.f_num}_{self.offset_point.x(), self.offset_point.y()}.png"
        # current_image is RGB; OpenCV writes BGR
        bgr_image = cv2.cvtColor(self.current_image, cv2.COLOR_RGB2BGR)
        bgr_orig_size = cv2.resize(bgr_image, (self.orig_W, self.orig_H))
        cv2.imwrite(save_path, bgr_orig_size)
        print(f"Image saved at {save_path}")

    def update_image(self):
        """Recompute the refocused image for the current point/F-number and show it."""
        depth_masks = self.depth_binning(self.image_depth, num_bins=16)
        blurred_image = self.adaptive_blur(depth_masks).astype(np.uint8)
        self.display_image(blurred_image)

    def display_image(self, rgb_image):
        """Show *rgb_image* on the label with the focus point drawn on top.

        Emits ``image_ready`` and stores the frame as ``self.current_image``.
        Assumes *rgb_image* is a C-contiguous uint8 (H, W, 3) array, as
        required by ``QImage.Format_RGB888`` with ``bytes_per_line = 3 * w``.
        """
        self.image_ready.emit(rgb_image)
        self.current_image = rgb_image
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w
        qt_image = QImage(rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qt_image)
        # Mark the selected focus point with a 5 px green dot
        pen = QPen(Qt.green)
        pen.setWidth(5)
        painter = QPainter(pixmap)
        painter.setPen(pen)
        painter.drawPoint(self.offset_point)
        painter.end()
        self.label.setPixmap(pixmap)

    def adaptive_blur(self, depth_masks):
        """Blur every depth bin except the one under the selected point.

        Args:
            depth_masks: Boolean masks (one per depth bin), as produced by
                ``depth_binning``, each with the height/width of ``self.image``.

        Returns:
            uint8 image with the same shape as ``self.image``. If the selected
            pixel lies in no bin, an unblurred copy of ``self.image`` is returned.

        Raises:
            ValueError: If ``self.f_num`` is not in the range (1, 22].
        """
        N = self.f_num
        if not 1 < N <= 22:
            raise ValueError(f"Invalid f-number: {N}")

        user_sl_point = (self.offset_point.y(), self.offset_point.x())
        usr_mask = self.get_user_select_mask_index(user_sl_point, depth_masks)
        if usr_mask < 0:
            # No bin contains the selected pixel (e.g. NaN depth): nothing to focus on
            return self.image.copy()

        min_d, max_d = self.image_depth.min(), self.image_depth.max()
        depth_span = float(max_d) - float(min_d)

        def get_real_depth(d):
            """Map a raw depth value linearly from [min_d, max_d] onto [1, 51]."""
            # NOTE(review): whether larger raw values mean nearer or farther
            # depends on get_depth's output convention - confirm.
            if depth_span == 0:
                return 1.0  # flat depth map: every pixel is at the same distance
            return 1 + (d - min_d) * 50 / depth_span

        f = 0.035  # focal length in meters (35 mm)
        sensor_width = 0.036  # sensor width in meters (36 mm)
        pixels_per_meter = self.image.shape[1] / sensor_width
        k = 60  # empirical scaling constant
        # The threshold grows with N, so small apertures leave more bins sharp
        threshold = 0.5 * (N / 8.0)

        focus_depth = get_real_depth(self.image_depth[depth_masks[usr_mask]].mean())
        final_img = np.zeros(shape=self.image.shape)
        for i, mask in enumerate(depth_masks):
            if not mask.any():
                continue  # empty bin contributes nothing; mean() would warn and give NaN
            mask_3ch = np.dstack([mask] * 3)
            if i == usr_mask:
                final_img += self.image * mask_3ch.astype(np.uint8)
                continue

            current_depth = get_real_depth(self.image_depth[mask].mean())
            # Circle of confusion. NOTE(review): this carries one more factor of f
            # than the textbook thin-lens formula f^2*|S2-S1| / (N*S2*(S1-f));
            # the empirical k and 2.5 factors appear tuned around it.
            numerator = f * f * abs(focus_depth - current_depth)
            denominator = current_depth * (focus_depth - f)
            CoC = abs((f / N) * (numerator / denominator)) * 2.5

            # Convert CoC to pixels, capped at 60 px
            CoC_pixels = np.clip(CoC * pixels_per_meter * k, 0, 60)
            radius = CoC_pixels / 2

            # Only blur when the radius is significant
            if radius > threshold:
                blur_image = self._defocus_blur(radius)
            else:
                blur_image = self.image
            final_img += blur_image * mask_3ch
        return final_img.astype(np.uint8)

    def _defocus_blur(self, radius):
        """Blur ``self.image`` with a centred Gaussian (sigma = *radius*), then bilateral-filter it."""
        half = int(math.ceil(radius))
        # Explicit symmetric range -half..half (odd size 2*half+1). Writing
        # ``-kernel_size//2`` would parse as ``(-kernel_size)//2`` and produce an
        # even-sized, off-centre kernel that shifts the blurred regions.
        y, x = np.ogrid[-half:half + 1, -half:half + 1]
        kernel = np.exp(-(x * x + y * y) / (2 * radius * radius))
        # A per-bin scalar depth weight would cancel out here, so it is omitted
        kernel /= kernel.sum()
        blurred = cv2.filter2D(src=self.image, ddepth=-1, kernel=kernel)
        return cv2.bilateralFilter(blurred.astype(np.uint8), 9, 75, 75)

    @staticmethod
    def depth_binning(img_d, num_bins):
        """Split *img_d* into *num_bins* equal-width intensity bins.

        Returns:
            A list of boolean masks shaped like *img_d*, one per bin. Bins are
            half-open ``[lo, hi)`` except the last, which is closed so pixels
            equal to the maximum are included. Every non-NaN pixel therefore
            lands in exactly one mask, including on a constant map.
        """
        if num_bins < 1:
            return []
        min_intensity, max_intensity = img_d.min(), img_d.max()
        bin_edges = np.linspace(min_intensity, max_intensity, num_bins + 1)
        masks = [
            np.logical_and(img_d >= bin_edges[i], img_d < bin_edges[i + 1])
            for i in range(num_bins - 1)
        ]
        masks.append(np.logical_and(img_d >= bin_edges[-2], img_d <= bin_edges[-1]))
        return masks

    @staticmethod
    def get_user_select_mask_index(coords, masks):
        """Return the index of the first mask that is True at *coords*, or -1.

        Args:
            coords: ``(row, col)`` pixel position. It is clamped into the mask
                bounds, so clicks outside the image neither raise nor wrap
                around via negative indexing.
            masks: Sequence of equally-shaped 2-D boolean masks.
        """
        if len(masks) == 0:
            return -1
        height, width = masks[0].shape[:2]
        row = min(max(coords[0], 0), height - 1)
        col = min(max(coords[1], 0), width - 1)
        return next((i for i, mask in enumerate(masks) if mask[row, col]), -1)
if __name__ == '__main__':
    import sys

    # Command-line entry point: open the refocusing window on one image
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_path', type=str, default='examples/00_f0.png', help='Path to image to be refocused')
    parser.add_argument('--F', type=float, default=4.0, help='Desired F-number')
    args = parser.parse_args()
    app = QApplication([])
    window = MainWindow(args.img_path, args.F, 'coc')  # kernel_t is stored but not used
    window.show()
    # Propagate Qt's exit status to the shell instead of always exiting 0
    sys.exit(app.exec_())