diff --git a/README.md b/README.md index 75d25a9..b6f62e4 100644 --- a/README.md +++ b/README.md @@ -58,20 +58,9 @@ python scripts/download_models.py ### Scripting Demo -This is probably the best starting point if you want to use Cutie in your project. Hopefully, the script is self-explanatory. If not, feel free to open an issue. Run `scripting_demo.py` to see it in action. For more advanced usage, like adding or removing objects, see `scripting_demo_add_del_objects.py`. +This is probably the best starting point if you want to use Cutie in your project. Hopefully, the script is self-explanatory (additional comments in `scripting_demo.py`). If not, feel free to open an issue. For more advanced usage, like adding or removing objects, see `scripting_demo_add_del_objects.py`. ```python -import os - -import torch -from torchvision.transforms.functional import to_tensor -from PIL import Image -import numpy as np - -from cutie.inference.inference_core import InferenceCore -from cutie.utils.get_default_model import get_default_model - - @torch.inference_mode() @torch.cuda.amp.autocast() def main(): @@ -92,15 +81,18 @@ def main(): image = to_tensor(image).cuda().float() if ti == 0: - prediction = processor.step(image, mask, objects=objects) + output_prob = processor.step(image, mask, objects=objects) else: - prediction = processor.step(image) + output_prob = processor.step(image) + + # convert output probabilities to an object mask + mask = processor.output_prob_to_mask(output_prob) # visualize prediction - prediction = torch.argmax(prediction, dim=0) - prediction = Image.fromarray(prediction.cpu().numpy().astype(np.uint8)) - prediction.putpalette(palette) - prediction.show() # or use prediction.save(...) to save it somewhere + mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8)) + mask.putpalette(palette) + mask.show() # or use mask.save(...) to save it somewhere + main() ``` diff --git a/cutie/inference/inference_core.py b/cutie/inference/inference_core.py index 096bf22..35da907 100644 --- a/cutie/inference/inference_core.py +++ b/cutie/inference/inference_core.py @@ -334,3 +334,13 @@ def delete_objects(self, objects: List[int]) -> None: """ self.object_manager.delete_objects(objects) self.memory.purge_except(self.object_manager.all_obj_ids) + + def output_prob_to_mask(self, output_prob: torch.Tensor) -> torch.Tensor: + mask = torch.argmax(output_prob, dim=0) + + # index in tensor != object id -- remap the ids here + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + + return new_mask diff --git a/scripting_demo.py b/scripting_demo.py index 20c8817..6eb646a 100644 --- a/scripting_demo.py +++ b/scripting_demo.py @@ -12,32 +12,53 @@ @torch.inference_mode() @torch.cuda.amp.autocast() def main(): - + # obtain the Cutie model with default parameters -- skipping hydra configuration cutie = get_default_model() + # Typically, use one InferenceCore per video processor = InferenceCore(cutie, cfg=cutie.cfg) image_path = './examples/images/bike' - images = sorted(os.listdir(image_path)) # ordering is important + # ordering is important + images = sorted(os.listdir(image_path)) + + # mask for the first frame + # NOTE: this should be a grayscale mask or a indexed (with/without palette) mask, + # and definitely NOT a colored RGB image + # https://pillow.readthedocs.io/en/stable/handbook/concepts.html: mode "L" or "P" mask = Image.open('./examples/masks/bike/00000.png') + assert mask.mode in ['L', 'P'] + + # palette is for visualization palette = mask.getpalette() + + # the number of objects is determined by counting the unique values in the mask + # common mistake: if the mask is resized w/ interpolation, there might be new unique values objects = np.unique(np.array(mask)) - objects = objects[objects != 0].tolist() # background "0" does not count as an object + # background "0" does not count as an object + objects = objects[objects != 0].tolist() + mask = torch.from_numpy(np.array(mask)).cuda() for ti, image_name in enumerate(images): + # load the image as RGB; normalization is done within the model image = Image.open(os.path.join(image_path, image_name)) image = to_tensor(image).cuda().float() if ti == 0: - prediction = processor.step(image, mask, objects=objects) + # if mask is passed in, it is memorized + # if not all objects are specified, we propagate the unspecified objects using memory + output_prob = processor.step(image, mask, objects=objects) else: - prediction = processor.step(image) + # otherwise, we propagate the mask from memory + output_prob = processor.step(image) + + # convert output probabilities to an object mask + mask = processor.output_prob_to_mask(output_prob) # visualize prediction - prediction = torch.argmax(prediction, dim=0) - prediction = Image.fromarray(prediction.cpu().numpy().astype(np.uint8)) - prediction.putpalette(palette) - prediction.show() # or use prediction.save(...) to save it somewhere + mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8)) + mask.putpalette(palette) + mask.show() # or use mask.save(...) to save it somewhere main() diff --git a/scripting_demo_add_del_objects.py b/scripting_demo_add_del_objects.py index fb99051..3904faa 100644 --- a/scripting_demo_add_del_objects.py +++ b/scripting_demo_add_del_objects.py @@ -12,15 +12,18 @@ @torch.inference_mode() @torch.cuda.amp.autocast() def main(): - + # obtain the Cutie model with default parameters -- skipping hydra configuration cutie = get_default_model() + # Typically, use one InferenceCore per video processor = InferenceCore(cutie, cfg=cutie.cfg) image_path = './examples/images/judo' mask_path = './examples/masks/judo' - images = sorted(os.listdir(image_path)) # ordering is important + # ordering is important + images = sorted(os.listdir(image_path)) for ti, image_name in enumerate(images): + # load the image as RGB; normalization is done within the model image = Image.open(os.path.join(image_path, image_name)) image = to_tensor(image).cuda().float() @@ -29,27 +32,35 @@ def main(): processor.delete_objects([1]) mask_name = image_name[:-4] + '.png' + + # we pass the mask in if it exists if os.path.exists(os.path.join(mask_path, mask_name)): - # add the objects in the mask + # NOTE: this should be a grayscale mask or a indexed (with/without palette) mask, + # and definitely NOT a colored RGB image + # https://pillow.readthedocs.io/en/stable/handbook/concepts.html: mode "L" or "P" mask = Image.open(os.path.join(mask_path, mask_name)) + + # palette is for visualization palette = mask.getpalette() + + # the number of objects is determined by counting the unique values in the mask + # common mistake: if the mask is resized w/ interpolation, there might be new unique values objects = np.unique(np.array(mask)) - objects = objects[objects != 0].tolist() # background "0" does not count as an object + # background "0" does not count as an object + objects = objects[objects != 0].tolist() mask = torch.from_numpy(np.array(mask)).cuda() - prediction = processor.step(image, mask, objects=objects) + # if mask is passed in, it is memorized + # if not all objects are specified, we propagate the unspecified objects using memory + output_prob = processor.step(image, mask, objects=objects) else: - prediction = processor.step(image) + # otherwise, we propagate the mask from memory + output_prob = processor.step(image) - # visualize prediction - mask = torch.argmax(prediction, dim=0) - - # since the objects might shift in the channel dim due to deletion, remap the ids - new_mask = torch.zeros_like(mask) - for tmp_id, obj in processor.object_manager.tmp_id_to_obj.items(): - new_mask[mask == tmp_id] = obj.id - mask = new_mask + # convert output probabilities to an object mask + mask = processor.output_prob_to_mask(output_prob) + # visualize prediction mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8)) mask.putpalette(palette) # mask.show() # or use prediction.save(...) to save it somewhere