diff --git a/README.md b/README.md index 0477814..75d25a9 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ python scripts/download_models.py ### Scripting Demo -(See also scripting_demo.py) +This is probably the best starting point if you want to use Cutie in your project. Hopefully, the script is self-explanatory. If not, feel free to open an issue. Run `scripting_demo.py` to see it in action. For more advanced usage, like adding or removing objects, see `scripting_demo_add_del_objects.py`. ```python import os diff --git a/cutie/inference/inference_core.py b/cutie/inference/inference_core.py index 93caf26..096bf22 100644 --- a/cutie/inference/inference_core.py +++ b/cutie/inference/inference_core.py @@ -16,6 +16,7 @@ class InferenceCore: + def __init__(self, network: CUTIE, cfg: DictConfig, @@ -327,55 +328,9 @@ def step(self, return output_prob - def get_aux_outputs(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: - image, pads = pad_divide_by(image, 16) - image = image.unsqueeze(0) # add the batch dimension - _, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) - - aux_inputs = self.memory.aux - aux_outputs = self.network.compute_aux(pix_feat, aux_inputs, selector=None) - aux_outputs['q_weights'] = aux_inputs['q_weights'] - aux_outputs['p_weights'] = aux_inputs['p_weights'] - - for k, v in aux_outputs.items(): - if len(v.shape) == 5: - aux_outputs[k] = F.interpolate(v[0], - size=image.shape[-2:], - mode='bilinear', - align_corners=False) - elif 'weights' in k: - b, num_objects, num_heads, num_queries, h, w = v.shape - v = v.view(num_objects * num_heads, num_queries, h, w) - v = F.interpolate(v, size=image.shape[-2:], mode='bilinear', align_corners=False) - aux_outputs[k] = v.view(num_objects, num_heads, num_queries, *image.shape[-2:]) - else: - aux_outputs[k] = F.interpolate(v, - size=image.shape[-2:], - mode='bilinear', - align_corners=False)[0] - aux_outputs[k] = unpad(aux_outputs[k], pads) - if 'weights' in k: - weights = aux_outputs[k] - weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0] + - 1e-8) - aux_outputs[k] = (weights * 255).cpu().numpy() - else: - aux_outputs[k] = (aux_outputs[k].softmax(dim=0) * 255).cpu().numpy() - - self.image_feature_store.delete(self.curr_ti) - return aux_outputs - - def get_aux_object_weights(self, image: torch.Tensor) -> np.ndarray: - image, pads = pad_divide_by(image, 16) - # B*num_objects*H*W*num_queries -> num_objects*num_queries*H*W - # weights = F.softmax(self.obj_logits, dim=-1)[0] - weights = F.sigmoid(self.obj_logits)[0] - weights = weights.permute(0, 3, 1, 2).contiguous() - weights = F.interpolate(weights, - size=image.shape[-2:], - mode='bilinear', - align_corners=False) - # weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0]) - weights = unpad(weights, pads) - weights = (weights * 255).cpu().numpy() - return weights + def delete_objects(self, objects: List[int]) -> None: + """ + Delete the given objects from the memory. + """ + self.object_manager.delete_objects(objects) + self.memory.purge_except(self.object_manager.all_obj_ids) diff --git a/cutie/inference/object_manager.py b/cutie/inference/object_manager.py index 93a60ee..edec93b 100644 --- a/cutie/inference/object_manager.py +++ b/cutie/inference/object_manager.py @@ -10,6 +10,7 @@ class ObjectManager: Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. Temporary IDs start from 1. """ + def __init__(self): self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} @@ -52,7 +53,7 @@ def add_new_objects( assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) return corresponding_tmp_ids, corresponding_obj_ids - def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None: + def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None: # delete an object or a list of objects # re-sort the tmp ids if isinstance(obj_ids_to_remove, int): @@ -93,7 +94,7 @@ def purge_inactive_objects(self, purge_activated = len(obj_id_to_be_deleted) > 0 if purge_activated: - self.delete_object(obj_id_to_be_deleted) + self.delete_objects(obj_id_to_be_deleted) return purge_activated, tmp_id_to_keep, obj_id_to_keep def tmp_to_obj_cls(self, mask) -> torch.Tensor: diff --git a/examples/images/judo/00000.jpg b/examples/images/judo/00000.jpg new file mode 100755 index 0000000..ad1d780 Binary files /dev/null and b/examples/images/judo/00000.jpg differ diff --git a/examples/images/judo/00001.jpg b/examples/images/judo/00001.jpg new file mode 100755 index 0000000..b2e32af Binary files /dev/null and b/examples/images/judo/00001.jpg differ diff --git a/examples/images/judo/00002.jpg b/examples/images/judo/00002.jpg new file mode 100755 index 0000000..776f24b Binary files /dev/null and b/examples/images/judo/00002.jpg differ diff --git a/examples/images/judo/00003.jpg b/examples/images/judo/00003.jpg new file mode 100755 index 0000000..494b2d9 Binary files /dev/null and b/examples/images/judo/00003.jpg differ diff --git a/examples/images/judo/00004.jpg b/examples/images/judo/00004.jpg new file mode 100755 index 0000000..c8c8484 Binary files /dev/null and b/examples/images/judo/00004.jpg differ diff --git a/examples/images/judo/00005.jpg b/examples/images/judo/00005.jpg new file mode 100755 index 0000000..90f07f9 Binary files /dev/null and b/examples/images/judo/00005.jpg differ diff --git a/examples/images/judo/00006.jpg b/examples/images/judo/00006.jpg new file mode 100755 index 0000000..aada315 Binary files /dev/null and b/examples/images/judo/00006.jpg differ diff --git a/examples/images/judo/00007.jpg b/examples/images/judo/00007.jpg new file mode 100755 index 0000000..24f7e90 Binary files /dev/null and b/examples/images/judo/00007.jpg differ diff --git a/examples/images/judo/00008.jpg b/examples/images/judo/00008.jpg new file mode 100755 index 0000000..c766995 Binary files /dev/null and b/examples/images/judo/00008.jpg differ diff --git a/examples/images/judo/00009.jpg b/examples/images/judo/00009.jpg new file mode 100755 index 0000000..42ab714 Binary files /dev/null and b/examples/images/judo/00009.jpg differ diff --git a/examples/images/judo/00010.jpg b/examples/images/judo/00010.jpg new file mode 100755 index 0000000..2020660 Binary files /dev/null and b/examples/images/judo/00010.jpg differ diff --git a/examples/images/judo/00011.jpg b/examples/images/judo/00011.jpg new file mode 100755 index 0000000..20cf853 Binary files /dev/null and b/examples/images/judo/00011.jpg differ diff --git a/examples/images/judo/00012.jpg b/examples/images/judo/00012.jpg new file mode 100755 index 0000000..d4d91f0 Binary files /dev/null and b/examples/images/judo/00012.jpg differ diff --git a/examples/images/judo/00013.jpg b/examples/images/judo/00013.jpg new file mode 100755 index 0000000..b37ae3a Binary files /dev/null and b/examples/images/judo/00013.jpg differ diff --git a/examples/images/judo/00014.jpg b/examples/images/judo/00014.jpg new file mode 100755 index 0000000..e146c03 Binary files /dev/null and b/examples/images/judo/00014.jpg differ diff --git a/examples/images/judo/00015.jpg b/examples/images/judo/00015.jpg new file mode 100755 index 0000000..f4335f2 Binary files /dev/null and b/examples/images/judo/00015.jpg differ diff --git a/examples/masks/judo/00000.png b/examples/masks/judo/00000.png new file mode 100644 index 0000000..8fb9ec9 Binary files /dev/null and b/examples/masks/judo/00000.png differ diff --git a/examples/masks/judo/00005.png b/examples/masks/judo/00005.png new file mode 100644 index 0000000..27e1cae Binary files /dev/null and b/examples/masks/judo/00005.png differ diff --git a/examples/masks/judo/00008.png b/examples/masks/judo/00008.png new file mode 100644 index 0000000..b6952eb Binary files /dev/null and b/examples/masks/judo/00008.png differ diff --git a/examples/masks/judo/00013.png b/examples/masks/judo/00013.png new file mode 100644 index 0000000..71b5de8 Binary files /dev/null and b/examples/masks/judo/00013.png differ diff --git a/scripting_demo_add_del_objects.py b/scripting_demo_add_del_objects.py new file mode 100644 index 0000000..fb99051 --- /dev/null +++ b/scripting_demo_add_del_objects.py @@ -0,0 +1,59 @@ +import os + +import torch +from torchvision.transforms.functional import to_tensor +from PIL import Image +import numpy as np + +from cutie.inference.inference_core import InferenceCore +from cutie.utils.get_default_model import get_default_model + + +@torch.inference_mode() +@torch.cuda.amp.autocast() +def main(): + + cutie = get_default_model() + processor = InferenceCore(cutie, cfg=cutie.cfg) + + image_path = './examples/images/judo' + mask_path = './examples/masks/judo' + images = sorted(os.listdir(image_path)) # ordering is important + + for ti, image_name in enumerate(images): + image = Image.open(os.path.join(image_path, image_name)) + image = to_tensor(image).cuda().float() + + # deleting the red mask at time step 10 for no reason -- you can set your own condition + if ti == 10: + processor.delete_objects([1]) + + mask_name = image_name[:-4] + '.png' + if os.path.exists(os.path.join(mask_path, mask_name)): + # add the objects in the mask + mask = Image.open(os.path.join(mask_path, mask_name)) + palette = mask.getpalette() + objects = np.unique(np.array(mask)) + objects = objects[objects != 0].tolist() # background "0" does not count as an object + mask = torch.from_numpy(np.array(mask)).cuda() + + prediction = processor.step(image, mask, objects=objects) + else: + prediction = processor.step(image) + + # visualize prediction + mask = torch.argmax(prediction, dim=0) + + # since the objects might shift in the channel dim due to deletion, remap the ids + new_mask = torch.zeros_like(mask) + for tmp_id, obj in processor.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + mask = new_mask + + mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8)) + mask.putpalette(palette) + # mask.show() # or use prediction.save(...) to save it somewhere + mask.save(os.path.join('./examples', mask_name)) + + +main()