Skip to content

Commit

Permalink
RGBA support
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Feb 17, 2024
1 parent 048e6b1 commit f330f5c
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 17 deletions.
4 changes: 2 additions & 2 deletions gui/TIPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ Controls:

Visualizations:

- Middle-click on target objects to toggle some visualization effects (for layered, popout, and binary mask export).
- Middle-click on target objects to toggle some visualization effects (for layered, popout, RGBA, and binary mask export).
- Soft masks are only saved for the "propagated" frames, not for the interacted frames. To save all frames, utilize forward and backward propagation.
- For some visualizations, the images saved during propagation will have higher quality with soft edges. This is because we have access to the soft mask only during propagation.
- For some visualizations (layered and RGBA), the images saved during propagation will be higher quality with soft edges. This is because we have access to the soft mask only during propagation. Set the save visualization mode to "Propagation only" to only save during propagation.
- The "layered" visualization mode inserts an RGBA layer between the foreground and the background. Use "import layer" to select a new layer.

Exporting:
Expand Down
22 changes: 18 additions & 4 deletions gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,17 @@ def __init__(self, controller, cfg: DictConfig) -> None:
self.combo.addItem("light")
self.combo.addItem("popup")
self.combo.addItem("layer")
self.combo.addItem("rgba")
self.combo.setCurrentText('davis')
self.combo.currentTextChanged.connect(controller.set_vis_mode)

self.save_visualization_checkbox = QCheckBox(self)
self.save_visualization_checkbox.toggled.connect(controller.on_save_visualization_toggle)
self.save_visualization_checkbox.setChecked(False)
self.save_visualization_combo = QComboBox(self)
self.save_visualization_combo.addItem("None")
self.save_visualization_combo.addItem("Always")
self.save_visualization_combo.addItem("Propagation only (higher quality)")
self.combo.setCurrentText('None')
self.save_visualization_combo.currentTextChanged.connect(
controller.on_set_save_visualization_mode)

self.save_soft_mask_checkbox = QCheckBox(self)
self.save_soft_mask_checkbox.toggled.connect(controller.on_save_soft_mask_toggle)
Expand Down Expand Up @@ -230,7 +235,7 @@ def __init__(self, controller, cfg: DictConfig) -> None:
overlay_topbox.addWidget(self.save_soft_mask_checkbox)
overlay_topbox.addWidget(self.export_binary_button)
overlay_botbox.addWidget(QLabel('Save visualization'))
overlay_botbox.addWidget(self.save_visualization_checkbox)
overlay_botbox.addWidget(self.save_visualization_combo)
overlay_botbox.addWidget(self.export_video_button)
overlay_botbox.addWidget(QLabel('Output FPS: '))
overlay_botbox.addWidget(self.fps_dial)
Expand Down Expand Up @@ -327,6 +332,15 @@ def text(self, text):

def set_canvas(self, image):
height, width, channel = image.shape
# if the image is RGBA, convert to RGB first by coloring the background green
if channel == 4:
image_rgb = image[:, :, :3].copy()
alpha = image[:, :, 3].astype(np.float32) / 255
green_bg = np.array([0, 255, 0])
# soft blending
image = (image_rgb * alpha[:, :, np.newaxis] + green_bg[np.newaxis, np.newaxis, :] *
(1 - alpha[:, :, np.newaxis])).astype(np.uint8)

bytesPerLine = 3 * width

qImg = QImage(image.data, width, height, bytesPerLine, QImage.Format.Format_RGB888)
Expand Down
33 changes: 29 additions & 4 deletions gui/interactive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def index_numpy_to_one_hot_torch(mask: np.ndarray, num_classes: int):
grayscale_weights_torch = torch.from_numpy(grayscale_weights).to(device).unsqueeze(0)


def get_visualization(mode: Literal['image', 'mask', 'fade', 'davis', 'light', 'popup',
'layer'], image: np.ndarray, mask: np.ndarray,
layer: np.ndarray, target_objects: List[int]) -> np.ndarray:
def get_visualization(mode: Literal['image', 'mask', 'fade', 'davis', 'light', 'popup', 'layer',
'rgba'], image: np.ndarray, mask: np.ndarray, layer: np.ndarray,
target_objects: List[int]) -> np.ndarray:
if mode == 'image':
return image
elif mode == 'mask':
Expand All @@ -70,12 +70,14 @@ def get_visualization(mode: Literal['image', 'mask', 'fade', 'davis', 'light', '
return overlay_davis(image, mask)
else:
return overlay_layer(image, mask, layer, target_objects)
elif mode == 'rgba':
return overlay_rgba(image, mask, target_objects)
else:
raise NotImplementedError


def get_visualization_torch(mode: Literal['image', 'mask', 'fade', 'davis', 'light', 'popup',
'layer'], image: torch.Tensor, prob: torch.Tensor,
'layer', 'rgba'], image: torch.Tensor, prob: torch.Tensor,
layer: torch.Tensor, target_objects: List[int]) -> np.ndarray:
if mode == 'image':
return image
Expand All @@ -96,6 +98,8 @@ def get_visualization_torch(mode: Literal['image', 'mask', 'fade', 'davis', 'lig
return overlay_davis_torch(image, prob)
else:
return overlay_layer_torch(image, prob, layer, target_objects)
elif mode == 'rgba':
return overlay_rgba_torch(image, prob, target_objects)
else:
raise NotImplementedError

Expand Down Expand Up @@ -138,6 +142,13 @@ def overlay_layer(image: np.ndarray, mask: np.ndarray, layer: np.ndarray,
return im_overlay.astype(image.dtype)


def overlay_rgba(image: np.ndarray, mask: np.ndarray, target_objects: List[int]):
# Put the mask is in the alpha channel
obj_mask = (np.isin(mask, target_objects)).astype(np.float32)[:, :, np.newaxis] * 255
im_overlay = np.concatenate([image, obj_mask], axis=-1)
return im_overlay.astype(image.dtype)


def overlay_davis_torch(image: torch.Tensor,
prob: torch.Tensor,
alpha: float = 0.5,
Expand Down Expand Up @@ -202,3 +213,17 @@ def overlay_layer_torch(image: torch.Tensor, prob: torch.Tensor, layer: torch.Te

im_overlay = (im_overlay * 255).byte().cpu().numpy()
return im_overlay


def overlay_rgba_torch(image: torch.Tensor, prob: torch.Tensor, target_objects: List[int]):
image = image.permute(1, 2, 0)

if len(target_objects) == 0:
obj_mask = torch.zeros_like(prob[0]).unsqueeze(2)
else:
# TODO: figure out why we need to convert this to numpy array
obj_mask = prob[np.array(target_objects, dtype=np.int32)].sum(0).unsqueeze(2)

im_overlay = torch.cat([image, obj_mask], dim=-1).clip(0, 1)
im_overlay = (im_overlay * 255).byte().cpu().numpy()
return im_overlay
13 changes: 8 additions & 5 deletions gui/main_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, cfg: DictConfig) -> None:
# visualization info
self.vis_mode: str = 'davis'
self.vis_image: np.ndarray = None
self.save_visualization: bool = False
self.save_visualization_mode: str = 'None'
self.save_soft_mask: bool = False

self.interacted_prob: torch.Tensor = None
Expand Down Expand Up @@ -219,7 +219,10 @@ def update_current_image_fast(self, invalid_soft_mask: bool = False):
self.vis_target_objects)
self.curr_image_torch = None
self.vis_image = np.ascontiguousarray(self.vis_image)
if self.save_visualization and not invalid_soft_mask:
save_visualization = self.save_visualization_mode in [
'Propagation only (higher quality)', 'Always'
]
if save_visualization and not invalid_soft_mask:
self.res_man.save_visualization(self.curr_ti, self.vis_mode, self.vis_image)
if self.save_soft_mask and not invalid_soft_mask:
self.res_man.save_soft_mask(self.curr_ti, self.curr_prob.cpu().numpy())
Expand All @@ -231,7 +234,7 @@ def show_current_frame(self, fast: bool = False, invalid_soft_mask: bool = False
self.update_current_image_fast(invalid_soft_mask)
else:
self.compose_current_im()
if self.save_visualization:
if self.save_visualization_mode == 'Always':
self.res_man.save_visualization(self.curr_ti, self.vis_mode, self.vis_image)
self.update_canvas()

Expand Down Expand Up @@ -595,8 +598,8 @@ def _try_load_layer(self, file_name):
except FileNotFoundError:
self.gui.text(f'{file_name} not found.')

def on_save_visualization_toggle(self):
self.save_visualization = self.gui.save_visualization_checkbox.isChecked()
def on_set_save_visualization_mode(self):
self.save_visualization_mode = self.gui.save_visualization_combo.currentText()

def on_save_soft_mask_toggle(self):
self.save_soft_mask = self.gui.save_soft_mask_checkbox.isChecked()
Expand Down
12 changes: 10 additions & 2 deletions gui/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# https://bugs.python.org/issue28178
# ah python ah why
class LRU:

def __init__(self, func, maxsize=128):
self.cache = collections.OrderedDict()
self.func = func
Expand Down Expand Up @@ -51,6 +52,7 @@ class SaveItem:


class ResourceManager:

def __init__(self, cfg: DictConfig):
# determine inputs
images = cfg['images']
Expand Down Expand Up @@ -152,9 +154,15 @@ def save_thread(self, queue: Queue):
elif args.type.startswith('visualization'):
# numpy array, save with cv2
vis_mode = args.type.split('_')[-1]
data = cv2.cvtColor(args.data, cv2.COLOR_RGB2BGR)
os.makedirs(path.join(self.visualization_dir, vis_mode), exist_ok=True)
cv2.imwrite(path.join(self.visualization_dir, vis_mode, args.name + '.jpg'), data)
if vis_mode == 'rgba':
data = cv2.cvtColor(args.data, cv2.COLOR_RGBA2BGRA).copy()
cv2.imwrite(path.join(self.visualization_dir, vis_mode, args.name + '.png'),
data)
else:
data = cv2.cvtColor(args.data, cv2.COLOR_RGB2BGR)
cv2.imwrite(path.join(self.visualization_dir, vis_mode, args.name + '.jpg'),
data)
elif args.type == 'soft_mask':
# numpy array, save each channel with cv2
num_channels = args.data.shape[0]
Expand Down

0 comments on commit f330f5c

Please sign in to comment.