# gui.py -- Tkinter front end for DepthCrafter (forked from Tencent/DepthCrafter).
import gc
import glob
import json
import os
import queue
import shutil
import threading
import tkinter as tk
from tkinter import filedialog, messagebox, ttk

import numpy as np
import torch
from diffusers.training_utils import set_seed

from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
from depthcrafter.utils import save_video, read_video_frames
class DepthCrafterDemo:
    """
    Class to handle the DepthCrafter inference.

    Loads the UNet and the pipeline once in ``__init__`` and exposes
    ``infer`` (one video in, one depth video out) and ``run`` (``infer``
    followed by memory cleanup).
    """

    # CPU offload strategies accepted by ``__init__``.
    CPU_OFFLOAD_MODES = ("model", "sequential")

    def __init__(self, unet_path: str, pre_train_path: str, cpu_offload: str = "model"):
        """
        Initializes the DepthCrafter pipeline.

        Args:
            unet_path (str): Path to the UNet model.
            pre_train_path (str): Path to the pre-trained model.
            cpu_offload (str, optional): CPU offload strategy ('model', 'sequential'). Defaults to "model".

        Raises:
            ValueError: If ``cpu_offload`` is not one of ``CPU_OFFLOAD_MODES``.
        """
        # Reject a bad option before the from_pretrained calls so the caller
        # gets the error without first paying for model loading.
        if cpu_offload not in self.CPU_OFFLOAD_MODES:
            raise ValueError(f"Unknown CPU offload option: {cpu_offload}")

        unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
            unet_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
        self.pipe = DepthCrafterPipeline.from_pretrained(
            pre_train_path,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        if cpu_offload == "sequential":
            self.pipe.enable_sequential_cpu_offload()
        else:
            self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing()

    @staticmethod
    def _normalize_depth(depth):
        """
        Rescales ``depth`` to [0, 1] using its global minimum and maximum.

        A constant input has zero range; dividing by it would fill the result
        with NaN, so an all-zero array of the same shape is returned instead.
        """
        low = depth.min()
        span = depth.max() - low
        if span == 0:
            return np.zeros_like(depth)
        return (depth - low) / span

    def infer(self, video, num_denoising_steps, guidance_scale, save_folder, window_size, process_length, overlap, max_res, seed):
        """
        Performs depth inference on a video.

        Args:
            video (str): Path to the input video.
            num_denoising_steps (int): Number of denoising steps.
            guidance_scale (float): Guidance scale for inference.
            save_folder (str): Folder to save the depth map video.
            window_size (int): Window size for temporal processing.
            process_length (int): Number of frames to process.
            overlap (int): Overlap between windows.
            max_res (int): Maximum resolution of the video.
            seed (int): Random seed for reproducibility.

        Returns:
            str: The save path of the depth map video
        """
        set_seed(seed)
        frames, target_fps = read_video_frames(video, process_length, -1, max_res, "open")
        with torch.inference_mode():
            res = self.pipe(
                frames,
                height=frames.shape[1],
                width=frames.shape[2],
                output_type="np",
                guidance_scale=guidance_scale,
                num_inference_steps=num_denoising_steps,
                window_size=window_size,
                overlap=overlap,
            ).frames[0]
        # Collapse the last axis by averaging, then normalize to [0, 1].
        res = res.sum(-1) / res.shape[-1]
        res = self._normalize_depth(res)

        # Output file is "<input stem>_depth.mp4" inside save_folder.
        stem = os.path.splitext(os.path.basename(video))[0]
        save_path = os.path.join(save_folder, stem) + "_depth.mp4"
        os.makedirs(save_folder, exist_ok=True)
        save_video(res, save_path, fps=target_fps)
        return save_path

    def run(self, video, **kwargs):
        """
        Runs the depth inference and handles cleanup.

        Args:
            video (str): Path to the input video.
            **kwargs: Additional parameters for inference.

        Returns:
            str: The save path returned by ``infer``.
        """
        try:
            return self.infer(video, **kwargs)
        finally:
            # Clean up even when inference fails, so a failed video does not
            # leave memory held for the next one.
            gc.collect()
            torch.cuda.empty_cache()
class DepthCrafterGUI:
    """
    GUI class for the DepthCrafter application.

    Processing runs on a background thread. That thread never touches Tk
    directly: settings are read on the Tk thread before the worker starts,
    and log lines / error dialogs are handed back through ``_ui_queue``,
    which the Tk thread drains periodically.
    """

    CONFIG_FILENAME = "config.json"
    # File suffixes (compared case-insensitively) treated as input videos.
    VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
    # Interval, in milliseconds, between drains of the worker message queue.
    UI_POLL_INTERVAL_MS = 100
    # Names of the Tk variables on this object; also the keys of the config file.
    _SETTING_KEYS = (
        "input_dir",
        "output_dir",
        "guidance_scale",
        "inference_steps",
        "window_size",
        "max_res",
        "overlap",
        "seed",
        "cpu_offload",
    )

    def __init__(self, root):
        """
        Initializes the GUI.

        Args:
            root (tk.Tk): The main Tkinter window.
        """
        self.root = root
        self.root.title("DepthCrafter GUI")

        # Default values before loading config
        self.input_dir = tk.StringVar(value="./input_clips")
        self.output_dir = tk.StringVar(value="./output_depthmaps")
        self.guidance_scale = tk.DoubleVar(value=1.0)
        self.inference_steps = tk.IntVar(value=5)
        self.window_size = tk.IntVar(value=110)
        self.max_res = tk.IntVar(value=960)
        self.overlap = tk.IntVar(value=25)
        self.seed = tk.IntVar(value=42)
        self.cpu_offload = tk.StringVar(value="model")

        # Attempt to load config from file
        self.load_config()

        self.processing_thread = None
        # (kind, payload) messages from the worker thread; kind is "log" or "error".
        self._ui_queue = queue.Queue()
        self.create_widgets()

        # Ensure settings are saved on exit
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)
        self._poll_ui_queue()

    def create_widgets(self):
        """Creates and arranges all the GUI widgets."""
        # Input/Output Folders
        frame = tk.LabelFrame(self.root, text="Directories")
        frame.pack(fill="x", padx=10, pady=5)
        tk.Label(frame, text="Input Folder:").grid(row=0, column=0, sticky="e")
        tk.Entry(frame, textvariable=self.input_dir, width=50).grid(row=0, column=1)
        tk.Button(frame, text="Browse", command=self.browse_input).grid(row=0, column=2)
        tk.Label(frame, text="Output Folder:").grid(row=1, column=0, sticky="e")
        tk.Entry(frame, textvariable=self.output_dir, width=50).grid(row=1, column=1)
        tk.Button(frame, text="Browse", command=self.browse_output).grid(row=1, column=2)

        # Parameters
        param_frame = tk.LabelFrame(self.root, text="Parameters")
        param_frame.pack(fill="x", padx=10, pady=5)
        self.add_param(param_frame, "Guidance Scale", self.guidance_scale, 0)
        self.add_param(param_frame, "Inference Steps", self.inference_steps, 1)
        self.add_param(param_frame, "Window Size", self.window_size, 2)
        self.add_param(param_frame, "Max Resolution", self.max_res, 3)
        self.add_param(param_frame, "Overlap", self.overlap, 4)
        self.add_param(param_frame, "Seed", self.seed, 5)
        tk.Label(param_frame, text="CPU Offload Mode:").grid(row=6, column=0, sticky="e")
        cpu_offload_box = ttk.Combobox(
            param_frame, textvariable=self.cpu_offload, values=["model", "sequential"]
        )
        cpu_offload_box.grid(row=6, column=1, padx=5)

        # Controls
        ctrl_frame = tk.Frame(self.root)
        ctrl_frame.pack(pady=10)
        tk.Button(ctrl_frame, text="Start", command=self.start_thread).pack(side="left", padx=5)
        tk.Button(ctrl_frame, text="Exit", command=self.on_close).pack(side="right", padx=5)

        # Logs
        log_frame = tk.LabelFrame(self.root, text="Log")
        log_frame.pack(fill="both", expand=True, padx=10, pady=5)
        self.log = tk.Text(log_frame, state="disabled", height=10)
        self.log.pack(fill="both", expand=True)

    def add_param(self, parent, label, var, row):
        """Helper function to create a parameter entry field."""
        tk.Label(parent, text=label + ":").grid(row=row, column=0, sticky="e")
        tk.Entry(parent, textvariable=var).grid(row=row, column=1, padx=5, pady=2)

    def browse_input(self):
        """Opens a file dialog to select the input folder and uses os.path.normpath to fix path formatting"""
        folder = filedialog.askdirectory(initialdir=os.path.normpath(self.input_dir.get()))
        if folder:
            self.input_dir.set(os.path.normpath(folder))

    def browse_output(self):
        """Opens a file dialog to select the output folder and uses os.path.normpath to fix path formatting"""
        folder = filedialog.askdirectory(initialdir=os.path.normpath(self.output_dir.get()))
        if folder:
            self.output_dir.set(os.path.normpath(folder))

    def log_message(self, message):
        """
        Logs a message to the GUI log.

        Safe to call from any thread: the message is queued and written to
        the log widget by the Tk thread on its next queue drain.
        """
        self._ui_queue.put(("log", message))

    def _append_log(self, message):
        """Appends one line to the log widget; must run on the Tk thread."""
        self.log.config(state="normal")
        self.log.insert("end", message + "\n")
        self.log.config(state="disabled")
        self.log.see("end")

    def _poll_ui_queue(self):
        """Drains queued worker messages on the Tk thread, then reschedules itself."""
        try:
            while True:
                kind, payload = self._ui_queue.get_nowait()
                if kind == "log":
                    self._append_log(payload)
                elif kind == "error":
                    messagebox.showerror("Error", payload)
        except queue.Empty:
            pass
        self.root.after(self.UI_POLL_INTERVAL_MS, self._poll_ui_queue)

    def _read_settings(self):
        """
        Returns the current value of every setting as a plain dict.

        Calls ``get()`` on each Tk variable, so it must run on the Tk thread;
        a numeric field holding invalid text makes ``get()`` raise.
        """
        return {key: getattr(self, key).get() for key in self._SETTING_KEYS}

    @classmethod
    def _collect_videos(cls, input_dir):
        """
        Returns the sorted paths of the video files directly inside ``input_dir``.

        A file qualifies when its name ends with one of ``VIDEO_EXTENSIONS``
        (case-insensitive) and does not start with a dot. A missing or
        non-directory ``input_dir`` yields an empty list.
        """
        try:
            names = sorted(os.listdir(input_dir))
        except (FileNotFoundError, NotADirectoryError):
            return []
        videos = []
        for name in names:
            if name.startswith(".") or not name.lower().endswith(cls.VIDEO_EXTENSIONS):
                continue
            path = os.path.join(input_dir, name)
            if os.path.isfile(path):
                videos.append(path)
        return videos

    def start_thread(self):
        """Starts a new thread for processing."""
        if self.processing_thread is not None and self.processing_thread.is_alive():
            self.log_message("Processing is already running.")
            return
        # Snapshot the settings here, on the Tk thread, so the worker never
        # has to read Tk variables itself.
        try:
            settings = self._read_settings()
        except (tk.TclError, ValueError) as e:
            messagebox.showerror("Error", f"Invalid parameter value: {e}")
            return
        self.processing_thread = threading.Thread(
            target=self.start_processing, args=(settings,), daemon=True
        )
        self.processing_thread.start()

    def start_processing(self, settings=None):
        """
        Main processing logic.

        Runs depth inference on every video in the input folder and moves
        each processed video into its "finished" subfolder.

        Args:
            settings (dict, optional): Snapshot produced by ``_read_settings``.
                When omitted, the settings are read from the Tk variables on
                the calling thread.
        """
        try:
            if settings is None:
                settings = self._read_settings()
            self.log_message("Starting processing...")

            input_dir = settings["input_dir"]
            videos = self._collect_videos(input_dir)
            if not videos:
                # Nothing to do, so skip loading the model entirely.
                self.log_message(f"No videos found in: {input_dir}")
                return

            demo = DepthCrafterDemo(
                unet_path="tencent/DepthCrafter",
                pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",
                cpu_offload=settings["cpu_offload"],
            )
            finished_folder = os.path.join(input_dir, "finished")
            # Ensure the 'finished' folder exists
            os.makedirs(finished_folder, exist_ok=True)
            for video in videos:
                self.log_message(f"Processing: {video}")
                demo.run(
                    video,
                    num_denoising_steps=settings["inference_steps"],
                    guidance_scale=settings["guidance_scale"],
                    save_folder=settings["output_dir"],
                    window_size=settings["window_size"],
                    process_length=-1,
                    overlap=settings["overlap"],
                    max_res=settings["max_res"],
                    seed=settings["seed"],
                )
                shutil.move(video, finished_folder)
            self.log_message("Processing complete!")
        except Exception as e:
            # Top-level boundary of the worker thread: report the failure
            # through the queue so the dialog is shown by the Tk thread.
            self.log_message(f"Error: {e}")
            self._ui_queue.put(("error", str(e)))

    def on_close(self):
        """
        Saves the config and closes the main window.

        The window is destroyed even when saving fails (for example when a
        numeric field contains invalid text); the failure is printed.
        """
        try:
            # Save configuration before closing
            self.save_config()
        except Exception as e:
            print(f"Warning: Could not save config: {e}")
        finally:
            self.root.destroy()

    def save_config(self):
        """
        Saves the current settings to the configuration file.
        """
        # Read every value first so a failing get() does not truncate the file.
        config = self._read_settings()
        with open(self.CONFIG_FILENAME, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=4)

    def load_config(self):
        """
        Loads settings from the configuration file.

        A missing file is ignored silently. An unreadable or malformed file
        leaves all defaults in place. Each present value is coerced to the
        type of its setting; a value that cannot be coerced is skipped with
        a warning and that setting keeps its current value.
        """
        try:
            with open(self.CONFIG_FILENAME, "r", encoding="utf-8") as f:
                config = json.load(f)
        except FileNotFoundError:
            return
        except (OSError, ValueError) as e:
            # If there's an error reading config, just use defaults
            print(f"Warning: Could not load config: {e}")
            return
        if not isinstance(config, dict):
            print("Warning: Could not load config: expected a JSON object")
            return

        # Use os.path.normpath to ensure path correctness
        casts = {
            "input_dir": os.path.normpath,
            "output_dir": os.path.normpath,
            "guidance_scale": float,
            "inference_steps": int,
            "window_size": int,
            "max_res": int,
            "overlap": int,
            "seed": int,
            "cpu_offload": str,
        }
        for key in self._SETTING_KEYS:
            if key not in config:
                continue
            try:
                value = casts[key](config[key])
            except (TypeError, ValueError) as e:
                print(f"Warning: Ignoring invalid config value for {key}: {e}")
                continue
            getattr(self, key).set(value)
if __name__ == "__main__":
    # Script entry point: create the top-level Tk window, attach the GUI to
    # it, and hand control to Tk's event loop until the window is closed.
    main_window = tk.Tk()
    app = DepthCrafterGUI(main_window)
    main_window.mainloop()