import os
import numpy as np
import torch
from transformers import pipeline
from PIL import Image, ImageFilter, ImageOps
import folder_paths
from comfy.model_management import get_torch_device, get_free_memory
import gc
import logging
from typing import Tuple, List, Dict, Any, Optional, Union

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DepthEstimation")

# Configure model paths
if not hasattr(folder_paths, "models_dir"):
    folder_paths.models_dir = os.path.join(folder_paths.base_path, "models")

# Register depth models path
DEPTH_DIR = "depth_anything"
folder_paths.folder_names_and_paths[DEPTH_DIR] = ([
    os.path.join(folder_paths.models_dir, DEPTH_DIR)
], folder_paths.supported_pt_extensions)

# Set models directory
MODELS_DIR = folder_paths.folder_names_and_paths[DEPTH_DIR][0][0]
os.makedirs(MODELS_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = MODELS_DIR

# Define all models mentioned in the README
DEPTH_MODELS = {
    "Depth-Anything-Small": "LiheYoung/depth-anything-small",
    "Depth-Anything-Base": "LiheYoung/depth-anything-base",
    "Depth-Anything-Large": "LiheYoung/depth-anything-large",
    "Depth-Anything-V2-Small": "LiheYoung/depth-anything-small-hf",
    "Depth-Anything-V2-Base": "LiheYoung/depth-anything-base-hf",
}
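
# Optional illustrative helper (a sketch, not used by the node itself): one way
# to pre-fetch one of the checkpoints above for offline use. The name
# `predownload_model` is hypothetical; it assumes the huggingface_hub package
# is available, which transformers already depends on.
def predownload_model(model_name: str) -> str:
    """Download a depth model into this node's cache directory (sketch)."""
    from huggingface_hub import snapshot_download  # lazy import, assumed installed
    repo_id = DEPTH_MODELS[model_name]
    cache_dir = os.path.join(MODELS_DIR, model_name.replace("-", "_").lower())
    # Returns the local filesystem path of the downloaded snapshot
    return snapshot_download(repo_id=repo_id, cache_dir=cache_dir)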


class DepthEstimationNode:
    """
    ComfyUI node for depth estimation using Depth Anything models.

    This node provides depth map generation from images using various Depth
    Anything models, with configurable post-processing options like blur,
    median filtering, contrast enhancement, and gamma correction.
    """

    MEDIAN_SIZES = ["3", "5", "7", "9", "11"]

    def __init__(self):
        self.device = None
        self.depth_estimator = None
        self.current_model = None
        logger.info("Initialized DepthEstimationNode")

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
        """Define the input types for the node."""
        return {
            "required": {
                "image": ("IMAGE",),
                "model_name": (list(DEPTH_MODELS.keys()),),
                "blur_radius": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "median_size": (cls.MEDIAN_SIZES, {"default": "5"}),
                "apply_auto_contrast": ("BOOLEAN", {"default": True}),
                "apply_gamma": ("BOOLEAN", {"default": True})
            }
        }
RETURN_TYPES = ("IMAGE",)
FUNCTION = "estimate_depth"
CATEGORY = "depth"

    def cleanup(self) -> None:
        """Clean up resources and free VRAM."""
        try:
            if self.depth_estimator is not None:
                del self.depth_estimator
                self.depth_estimator = None
                self.current_model = None
            # Force CUDA cache clearing
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            logger.info("Cleaned up model resources")
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

    def ensure_model_loaded(self, model_name: str) -> None:
        """
        Ensures the correct model is loaded, with proper VRAM management and fallback options.

        Args:
            model_name: The name of the model to load

        Raises:
            RuntimeError: If the model fails to load after all fallback attempts
        """
        try:
            if model_name not in DEPTH_MODELS:
                raise ValueError(f"Unknown model: {model_name}. Available models: {list(DEPTH_MODELS.keys())}")
            model_path = DEPTH_MODELS[model_name]

            # Only reload if needed
            if self.depth_estimator is None or self.current_model != model_path:
                self.cleanup()

                # Set up device
                if self.device is None:
                    self.device = get_torch_device()
                logger.info(f"Loading depth model: {model_name} on device {self.device}")

                # Determine device type for the pipeline
                device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
                # Use FP16 on CUDA devices to save VRAM
                dtype = torch.float16 if 'cuda' in str(self.device) else torch.float32

                # Create a dedicated cache directory for this model
                cache_dir = os.path.join(MODELS_DIR, model_name.replace("-", "_").lower())
                os.makedirs(cache_dir, exist_ok=True)

                # List of model paths to try (original and fallbacks)
                model_paths_to_try = [
                    model_path,  # Original path
                    model_path + "-hf",  # Try with -hf suffix
                    model_path.replace("depth-anything", "depth-anything-hf")  # Alternative format
                ]

                # Try each model path
                success = False
                last_error = None
                for path in model_paths_to_try:
                    try:
                        logger.info(f"Attempting to load from: {path}")
                        # Try online mode first
                        try:
                            self.depth_estimator = pipeline(
                                "depth-estimation",
                                model=path,
                                cache_dir=cache_dir,
                                local_files_only=False,  # Try online first
                                device_map=device_type,
                                torch_dtype=dtype
                            )
                            success = True
                            logger.info(f"Successfully loaded model from {path}")
                            break
                        except Exception as online_error:
                            logger.warning(f"Online loading failed for {path}: {online_error}")
                            # Fall back to the local cache if online loading fails
                            try:
                                self.depth_estimator = pipeline(
                                    "depth-estimation",
                                    model=path,
                                    cache_dir=cache_dir,
                                    local_files_only=True,  # Local cache as fallback
                                    device_map=device_type,
                                    torch_dtype=dtype
                                )
                                success = True
                                logger.info(f"Successfully loaded model from local cache: {path}")
                                break
                            except Exception as local_error:
                                last_error = local_error
                                logger.warning(f"Local loading failed for {path}: {local_error}")
                                continue
                    except Exception as path_error:
                        last_error = path_error
                        logger.warning(f"Failed to load model from {path}: {path_error}")
                        continue

                if not success:
                    # All attempts failed; raise a helpful message with instructions
                    error_msg = (
                        f"Failed to load model {model_name} after trying multiple sources.\n"
                        f"Last error: {last_error}\n"
                        "Try these solutions:\n"
                        "1. Run 'huggingface-cli login' in your terminal to authenticate\n"
                        "2. Check your internet connection\n"
                        "3. Try a different model version (e.g. Depth-Anything-V2-Small instead of Depth-Anything-Small)"
                    )
                    logger.error(error_msg)
                    raise RuntimeError(error_msg)

                # Ensure the model is on the correct device
                if hasattr(self.depth_estimator, 'model'):
                    self.depth_estimator.model = self.depth_estimator.model.to(self.device)
                self.current_model = model_path
        except Exception as e:
            self.cleanup()
            error_msg = f"Failed to load model {model_name}: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

    def process_image(self, image: Union[torch.Tensor, np.ndarray]) -> Image.Image:
        """
        Converts an input image to the proper format for depth estimation.

        Args:
            image: Input image as a tensor or numpy array

        Returns:
            PIL Image ready for depth estimation
        """
        if torch.is_tensor(image):
            # ComfyUI IMAGE tensors are (batch, h, w, c) floats in [0, 1];
            # take the first image in the batch and scale to 8-bit
            image_np = (image.cpu().numpy()[0] * 255).astype(np.uint8)
        else:
            image_np = (image * 255).astype(np.uint8)
        if len(image_np.shape) == 3:
            if image_np.shape[-1] == 4:  # Handle RGBA images
                image_np = image_np[..., :3]
        elif len(image_np.shape) == 2:  # Handle grayscale images
            image_np = np.stack([image_np] * 3, axis=-1)
        return Image.fromarray(image_np)

    def estimate_depth(self,
                       image: torch.Tensor,
                       model_name: str,
                       blur_radius: float = 2.0,
                       median_size: str = "5",
                       apply_auto_contrast: bool = True,
                       apply_gamma: bool = True) -> Tuple[torch.Tensor]:
        """
        Estimates depth from an input image with error handling and cleanup.

        Args:
            image: Input image tensor
            model_name: Name of the depth model to use
            blur_radius: Gaussian blur radius for smoothing
            median_size: Size of the median filter for noise reduction
            apply_auto_contrast: Whether to enhance contrast automatically
            apply_gamma: Whether to apply gamma correction

        Returns:
            Tuple containing the depth map tensor

        Raises:
            RuntimeError: If depth estimation fails
            ValueError: If invalid parameters are provided
        """
        try:
            if median_size not in self.MEDIAN_SIZES:
                raise ValueError(f"Invalid median_size. Must be one of {self.MEDIAN_SIZES}")
            self.ensure_model_loaded(model_name)
            pil_image = self.process_image(image)

            with torch.inference_mode():
                depth_result = self.depth_estimator(pil_image)
                depth_map = depth_result["predicted_depth"].squeeze().cpu().numpy()

            # Normalize depth values to the 0-255 range
            depth_min, depth_max = depth_map.min(), depth_map.max()
            if depth_max > depth_min:
                depth_map = (depth_map - depth_min) / (depth_max - depth_min) * 255.0
            depth_map = depth_map.astype(np.uint8)

            # Create a PIL image explicitly in L mode (grayscale)
            depth_pil = Image.fromarray(depth_map, mode='L')

            # Apply post-processing
            if blur_radius > 0:
                depth_pil = depth_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            if int(median_size) > 0:
                depth_pil = depth_pil.filter(ImageFilter.MedianFilter(size=int(median_size)))
            if apply_auto_contrast:
                depth_pil = ImageOps.autocontrast(depth_pil)
            if apply_gamma:
                depth_array = np.array(depth_pil).astype(np.float32) / 255.0
                mean_luminance = np.mean(depth_array)
                # Guard against log(0) and log(1) edge cases
                if 0 < mean_luminance < 1:
                    # Auto-gamma: pick the exponent that maps the mean luminance
                    # to mid-gray, since mean ** gamma == 0.5 exactly when
                    # gamma = log(0.5) / log(mean). Apply gamma directly
                    # (not 1/gamma) so that mapping actually holds.
                    gamma = np.log(0.5) / np.log(mean_luminance)
                    corrected = np.power(depth_array, gamma) * 255.0
                    depth_pil = Image.fromarray(corrected.astype(np.uint8), mode='L')

            # Convert to tensor, explicitly handling it as grayscale
            depth_array = np.array(depth_pil).astype(np.float32) / 255.0

            # Make it compatible with ComfyUI by creating a 3-channel image;
            # stacking gives a proper 3D array of shape (h, w, 3)
            depth_rgb = np.stack([depth_array] * 3, axis=-1)
            depth_tensor = torch.from_numpy(depth_rgb).unsqueeze(0)
            if self.device is not None:
                depth_tensor = depth_tensor.to(self.device)
            return (depth_tensor,)
        except Exception as e:
            error_msg = f"Depth estimation failed: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    def gamma_correction(self, img: Image.Image, gamma: float = 1.0) -> Image.Image:
        """Applies gamma correction to a grayscale (L mode) image."""
        # Scale to [0, 1], apply the standard out = in ** (1 / gamma) curve,
        # then scale back to 8-bit
        img_array = np.array(img).astype(np.float32) / 255.0
        corrected = np.power(img_array, 1.0 / gamma) * 255.0
        # Ensure uint8 type and create the image with an explicit mode
        return Image.fromarray(corrected.astype(np.uint8), mode='L')


# Node registration
NODE_CLASS_MAPPINGS = {
    "DepthEstimationNode": DepthEstimationNode
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "DepthEstimationNode": "Depth Estimation (V2)"
}
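

# Minimal smoke test, as a sketch: this only runs inside a ComfyUI Python
# environment (folder_paths and comfy must be importable), and the first call
# needs network access or an already-cached checkpoint. The random tensor is a
# stand-in for ComfyUI's IMAGE format: (batch, height, width, 3) floats in [0, 1].
if __name__ == "__main__":
    node = DepthEstimationNode()
    dummy_image = torch.rand(1, 256, 256, 3)  # hypothetical test input, not a real photo
    (depth,) = node.estimate_depth(dummy_image, "Depth-Anything-Small")
    print(f"Depth map tensor shape: {tuple(depth.shape)}")  # expect (1, 256, 256, 3)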