-
Notifications
You must be signed in to change notification settings - Fork 0
/
gui-demo2.py
308 lines (238 loc) · 10.4 KB
/
gui-demo2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import tkinter as tk
from tkinter import filedialog
import threading
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import numpy as np
from PIL import Image, ImageTk # You need to install the Python Imaging Library (PIL)
# from artificial_artwork._demo import create_algo_runner
from artificial_artwork._main import create_algo_runner
from artificial_artwork.image import convert_to_uint8
# CONSTANTS
# UI text for each selectable input image, keyed by image type ('content' / 'style').
IMAGE_COMP_ASSETS = dict(
    content=dict(
        load_button_text="Select Content Image",
        label_text="Content Image:",
    ),
    style=dict(
        load_button_text="Select Style Image",
        label_text="Style Image:",
    ),
)

# Main window size, formatted as "<width>x<height>" for Tk's geometry().
WINDOW_GEOMETRY: str = '2600x1800'

# Bounding box (width, height) used when rendering the content and style thumbnails.
INPUT_IMAGE_THUMBNAIL_SIZE = (200, 200)

# Bounding box (width, height) used when rendering the generated image.
GENERATED_IMAGE_THUMBNAIL_SIZE = (500, 500)

# Helpers Objects
# Maps an image type ('content' / 'style') to the file path the user selected for it.
img_type_2_path = dict()
# Helper Functions
def _build_open_image_dialog_callback(image_file_label, image_type: str):
    """Build a button callback that lets the user pick a file for ``image_type``.

    When invoked, the callback opens a file chooser. If the user selects a file,
    its path is stored in ``img_type_2_path[image_type]`` and shown in
    ``image_file_label`` after the image type's label text. If the dialog is
    cancelled, nothing changes.
    """
    def _open_image_dialog():
        selected_path = filedialog.askopenfilename()
        if not selected_path:
            # Dialog cancelled: keep the previous selection and label untouched.
            return
        img_type_2_path[image_type] = selected_path
        label_prefix = IMAGE_COMP_ASSETS[image_type]["label_text"]
        image_file_label.config(text=f'{label_prefix} {selected_path}')

    return _open_image_dialog
def _build_open_image_dialog_callback_v2(x, image_type: str):
    """Build a button callback that selects an image for ``image_type`` and previews it.

    Args:
        x: mapping with the widgets to update: ``'image_label'`` (shows the chosen
            path) and ``'image_pane'`` (shows a thumbnail of the chosen image).
        image_type: key into ``IMAGE_COMP_ASSETS`` / ``img_type_2_path``
            (``'content'`` or ``'style'``).

    Returns:
        A zero-argument callable. It opens a file chooser. On a selection, it
        decodes the image, renders a thumbnail no larger than
        ``INPUT_IMAGE_THUMBNAIL_SIZE`` and records the path in
        ``img_type_2_path``. If the user cancels, it does nothing.
    """
    def _open_file_dialog_v2():
        file_path = filedialog.askopenfilename()
        if not file_path:
            return
        image_label = x['image_label']
        image_pane = x['image_pane']
        # Decode inside a context manager so Pillow's underlying file handle is
        # released once the thumbnail has been copied into the PhotoImage.
        with Image.open(file_path) as image:
            image.thumbnail(INPUT_IMAGE_THUMBNAIL_SIZE)  # Resize the image to fit in the pane
            photo = ImageTk.PhotoImage(image=image)
        # Record the path only after the file decoded successfully. An unreadable
        # file raises above and does not replace the previous (valid) selection.
        img_type_2_path[image_type] = file_path
        image_pane.config(image=photo)
        # Keep a Python reference on the widget so the PhotoImage is not
        # garbage-collected while it is displayed.
        image_pane.image = photo
        image_label.config(text=f'{IMAGE_COMP_ASSETS[image_type]["label_text"]} {file_path}')
        image_label.update_idletasks()

    return _open_file_dialog_v2
# MAIN
def _make_image_component_data(image_type):
    """Return the UI text plus the dialog-callback factories for one image type."""
    return {
        **IMAGE_COMP_ASSETS[image_type],
        # Factory: Tk label -> callback that only records and shows the chosen path.
        'image_dialog_from_label': lambda label_obj: _build_open_image_dialog_callback(label_obj, image_type),
        # Factory: {'image_label', 'image_pane'} widgets -> callback that also previews the image.
        'image_dialog': lambda x: _build_open_image_dialog_callback_v2(x, image_type),
    }


# Per image type ('content' / 'style'): button/label text and dialog callback factories.
images_components_data = {
    image_type: _make_image_component_data(image_type)
    for image_type in ('content', 'style')
}
# Create the main window
root = tk.Tk()
root.title("Neural Style Transfer - Desktop")
# Size the window from the shared WINDOW_GEOMETRY constant ("<width>x<height>")
# instead of repeating the literal, so the size is configured in one place.
root.geometry(WINDOW_GEOMETRY)
# Add a label to describe the purpose of the GUI
description_label = tk.Label(root, text="Select a file using the buttons below:")
description_label.pack(pady=10)  # Add padding
# CONTENT IMAGE UI/UX
def _on_select_content_image():
    """Button handler: pick the content image and preview it in the content pane."""
    target_widgets = {
        'image_label': content_image_label,
        'image_pane': content_image_pane,
    }
    open_dialog = images_components_data['content']['image_dialog'](target_widgets)
    open_dialog()


# BUTTON -> Load Content Image
button1 = tk.Button(
    root,
    text=images_components_data['content']['load_button_text'],
    command=_on_select_content_image,
)
button1.pack(pady=5)

# LABEL -> path of the loaded content image
content_image_label = tk.Label(root, text=images_components_data['content']['label_text'])
content_image_label.pack()

# LABEL -> pane that renders the content image thumbnail (starts with zero size)
content_image_pane = tk.Label(root, width=0, height=0, bg="white")
content_image_pane.pack()
# STYLE IMAGE UI/UX
def _on_select_style_image():
    """Button handler: pick the style image and preview it in the style pane."""
    target_widgets = {
        'image_label': style_image_label,
        'image_pane': style_image_pane,
    }
    open_dialog = images_components_data['style']['image_dialog'](target_widgets)
    open_dialog()


# BUTTON -> Load Style Image
load_style_image_btn = tk.Button(
    root,
    text=images_components_data['style']['load_button_text'],
    command=_on_select_style_image,
)
load_style_image_btn.pack(pady=5)

# LABEL -> path of the loaded style image
style_image_label = tk.Label(root, text=images_components_data['style']['label_text'])
style_image_label.pack()

# LABEL -> pane that renders the style image thumbnail (starts with zero size)
style_image_pane = tk.Label(root, width=0, height=0, bg="white")
style_image_pane.pack()
# GENERATED IMAGE UI/UX
#### UPDATE UI based on BACKEND progress ####
def update_image_thread(progress, gen_image_pane, _iteration_count_label, fig, combined_subplot):
    """Render the backend's latest generated image and cost metrics in the GUI.

    Despite its name, this function does not start a thread. It is the observer
    callback that ``run_nst`` subscribes to the backend, and ``run_nst`` itself
    runs on a worker thread.

    Args:
        progress: backend progress object. ``progress.state.matrix`` holds the
            generated image as a numpy array (optionally with a leading batch
            axis of size 1). ``progress.state.metrics`` is a dict expected to
            contain ``'iterations'``. Once costs have been evaluated, it also
            contains ``'cost'``, ``'style-cost-weighted'`` and
            ``'content-cost-weighted'``.
        gen_image_pane: Tk label in which the generated image is displayed.
        _iteration_count_label: Tk label showing the current iteration count.
        fig: unused; kept so the signature stays backward compatible.
        combined_subplot: matplotlib axes that ``update_chart`` redraws.

    NOTE(review): this touches Tk widgets from a non-GUI thread. Tkinter is not
    guaranteed to be thread-safe, so consider handing the data to the main loop
    instead (e.g. a ``queue.Queue`` polled with ``root.after``).
    """
    matrix = progress.state.matrix
    metrics = progress.state.metrics
    current_iteration_count: int = metrics['iterations']
    # Drop a leading batch axis of size 1: (1, H, W, C) -> (H, W, C).
    # Any other shape is used as-is. Previously such shapes left `matrix`
    # unbound and the update crashed with UnboundLocalError.
    if matrix.ndim == 4 and matrix.shape[0] == 1:
        matrix = matrix[0]
    if matrix.dtype != np.uint8:
        matrix = convert_to_uint8(matrix)
    # Pillow cannot build an image from an (H, W, 1) array; show it as (H, W) grayscale.
    if matrix.ndim == 3 and matrix.shape[-1] == 1:
        matrix = matrix[..., 0]

    image = Image.fromarray(matrix)
    # Resize the image to fit in the pane
    image.thumbnail(GENERATED_IMAGE_THUMBNAIL_SIZE)
    photo = ImageTk.PhotoImage(image=image)
    gen_image_pane.config(image=photo)
    # Keep a Python reference on the widget so the PhotoImage is not garbage-collected.
    gen_image_pane.image = photo
    _iteration_count_label.config(text=f'Iteration Count: {current_iteration_count}')

    if 'cost' in metrics:  # backend has evaluated the costs into scalars (floats)
        # Append this iteration's values to the module-level metric histories.
        total_cost_values.append(metrics['cost'])
        style_cost_values.append(metrics['style-cost-weighted'])
        content_cost_values.append(metrics['content-cost-weighted'])
        iteration_values.append(current_iteration_count)
        update_chart(
            iteration_values,
            total_cost_values,
            style_cost_values,
            content_cost_values,
            combined_subplot,
        )
# GENERATED IMAGE / ITERATION COUNT WIDGETS
# Widgets are created first and then packed in display order:
# caption, live image pane (updated during the learning loop), iteration counter.
generated_image_label = tk.Label(root, text="Generated Image:")
generated_image_pane = tk.Label(root, width=0, height=0, bg="white")  # starts with zero size
iteration_count_label = tk.Label(root, text="Iteration Count:")

generated_image_label.pack(pady=10)
generated_image_pane.pack(pady=5)
iteration_count_label.pack(pady=5)
# RUN NST ALGORITHM UI/UX
# Helper Run Functions
def run_nst(fig, combined_subplot):
    """Run the NST algorithm on the selected images, streaming progress to the GUI.

    Intended to run on a background thread (see ``start_nst_thread``) so the Tk
    event loop stays responsive.

    Args:
        fig: matplotlib figure of the metrics chart (forwarded to the observer).
        combined_subplot: axes on which the metrics chart is drawn.

    Returns early, without building the backend, if either the content image
    or the style image has not been selected yet.
    """
    # Use .get(): indexing the dict directly raised KeyError when an image was
    # not selected, so the guard below could never take effect.
    content_image_path = img_type_2_path.get('content')
    style_image_path = img_type_2_path.get('style')
    if not (content_image_path and style_image_path):
        print('Please select both a content image and a style image before running NST.')
        return

    # Start each run with empty metric histories so the chart shows only this run.
    for history in (total_cost_values, style_cost_values, content_cost_values, iteration_values):
        history.clear()

    backend_object = create_algo_runner(
        iterations=100,  # NB of Times to pass Image through the Network
        output_folder='gui-output-folder',  # Output Folder to store gen img snapshots
        noisy_ratio=0.6,
    )
    # The backend is subscribed with this class object itself (not an instance).
    # Its `update` is a plain one-argument function that forwards each progress
    # report to the GUI. Presumably the backend calls it once per reported
    # iteration; confirm against artificial_artwork.
    observer = type('Observer', (), {
        'update': lambda progress: update_image_thread(
            progress,
            generated_image_pane,
            iteration_count_label,
            fig,
            combined_subplot,
        ),
    })
    backend_object['subscribe'](observer)
    backend_object['run'](
        content_image_path,
        style_image_path,
    )
# Define Thread to run the NST Algorithm
# Worker thread of the NST run in progress (None before the first run).
_nst_thread = None
# (fig, combined_subplot) of the metrics chart. It is created once and reused,
# so repeated runs do not stack additional chart canvases in the window.
_nst_graph = None


def start_nst_thread():
    """Button handler: start the NST algorithm on a daemon worker thread.

    Clicks made while a previous run is still alive are ignored. This prevents
    several concurrent NST runs from updating the same widgets. The metrics
    chart is created on the first click (on the Tk main thread) and reused
    afterwards.
    """
    global _nst_thread, _nst_graph
    if _nst_thread is not None and _nst_thread.is_alive():
        return
    if _nst_graph is None:
        _nst_graph = initialize_graph(root)
    fig, combined_subplot = _nst_graph
    # Daemon thread: it does not keep the process alive once the main program exits.
    _nst_thread = threading.Thread(target=run_nst, args=(fig, combined_subplot), daemon=True)
    _nst_thread.start()
# BUTTON -> launches the NST algorithm on a background thread when pressed
run_nst_btn = tk.Button(root, text="Run NST Algorithm", command=start_nst_thread)
run_nst_btn.pack(pady=5)
# PLOTTING
# Metric histories, appended to by update_image_thread and plotted by update_chart.
# Each name is bound to its own (initially empty) list.
total_cost_values, style_cost_values, content_cost_values, iteration_values = [], [], [], []
# Helper Functions
def initialize_graph(root):
    """Create the metrics chart and embed it in the Tk container ``root``.

    Args:
        root: Tk widget the chart canvas is packed into (top side, filling and
            expanding with the window).

    Returns:
        ``(fig, combined_subplot)``: the matplotlib Figure and the single Axes
        on which the cost metrics are plotted.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps a
    # global reference to every figure it creates (never closed here) and is
    # not needed when embedding into an existing Tk application.
    fig = Figure(figsize=(8, 6))
    combined_subplot = fig.add_subplot(111)
    combined_subplot.set_title('Metrics Over Iterations')
    combined_subplot.set_xlabel('Iterations')
    combined_subplot.set_ylabel('Metric Values')

    canvas = FigureCanvasTkAgg(fig, master=root)
    # Render the empty, labelled axes now instead of showing a blank area
    # until the first metrics update arrives.
    canvas.draw()
    canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
    return fig, combined_subplot
# Update Matplotlib chart with metrics data
def update_chart(_iteration_values, _total_cost_values, _style_cost_values, _content_cost_values, _combined_subplot):
    """Redraw the metrics chart from scratch with the given histories.

    The three cost series are plotted against ``_iteration_values`` on
    ``_combined_subplot``. The title, axis labels and legend are then restored
    and the owning figure's canvas is redrawn.
    """
    axes = _combined_subplot
    series = (
        (_total_cost_values, 'Total Cost', 'o'),
        (_style_cost_values, 'Weighted Style Cost', 's'),
        (_content_cost_values, 'Weighted Content Cost', 'x'),
    )
    # clear() also drops title and labels, so they are set again after plotting.
    axes.clear()
    for values, series_label, series_marker in series:
        axes.plot(_iteration_values, values, label=series_label, marker=series_marker)
    axes.set_title('Metrics Over Iterations')
    axes.set_xlabel('Iterations')
    axes.set_ylabel('Metric Values')
    axes.legend()
    axes.figure.canvas.draw()
# TKINTER MAIN LOOP
# Enter the blocking Tk event loop only when executed as a script, so that
# importing this module (e.g. to reuse its helpers) does not hang.
if __name__ == "__main__":
    root.mainloop()