-
Notifications
You must be signed in to change notification settings - Fork 76
/
scripting_demo.py
67 lines (52 loc) · 2.46 KB
/
scripting_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import torch
from torchvision.transforms.functional import to_tensor
from PIL import Image
import numpy as np
from cutie.inference.inference_core import InferenceCore
from cutie.utils.get_default_model import get_default_model
@torch.inference_mode()
@torch.cuda.amp.autocast()
def main():
    """Segment the bundled "bike" example video and display every predicted mask.

    The first-frame mask is read from disk and handed to the processor together
    with the first frame; all later frames are processed without a mask so the
    prediction is propagated from memory. Every tensor is moved with ``.cuda()``,
    so a CUDA device is required.
    """
    # obtain the Cutie model with default parameters -- skipping hydra configuration
    cutie = get_default_model()
    # Typically, use one InferenceCore per video
    processor = InferenceCore(cutie, cfg=cutie.cfg)
    # the processor matches the shorter edge of the input to this size
    # you might want to experiment with different sizes, -1 keeps the original size
    processor.max_internal_size = 480

    image_path = './examples/images/bike'
    # ordering is important
    images = sorted(os.listdir(image_path))

    # mask for the first frame
    # NOTE: this should be a grayscale mask or an indexed (with/without palette) mask,
    # and definitely NOT a colored RGB image
    # https://pillow.readthedocs.io/en/stable/handbook/concepts.html: mode "L" or "P"
    with Image.open('./examples/masks/bike/00000.png') as first_mask:
        assert first_mask.mode in ['L', 'P'], (
            f'mask must be mode "L" or "P", got "{first_mask.mode}"')
        # palette is for visualization only; getpalette() returns None when the
        # image has no palette (e.g. a plain grayscale "L" mask)
        palette = first_mask.getpalette()
        # materialize the pixel data once, before the file is closed
        mask_array = np.array(first_mask)

    # the number of objects is determined by counting the unique values in the mask
    # common mistake: if the mask is resized w/ interpolation, there might be new unique values
    objects = np.unique(mask_array)
    # background "0" does not count as an object
    objects = objects[objects != 0].tolist()

    mask = torch.from_numpy(mask_array).cuda()

    for ti, image_name in enumerate(images):
        # load the image as RGB; normalization is done within the model.
        # convert() guarantees three channels even for RGBA/grayscale/palettized files.
        with Image.open(os.path.join(image_path, image_name)) as frame:
            image = to_tensor(frame.convert('RGB')).cuda().float()

        if ti == 0:
            # if mask is passed in, it is memorized
            # if not all objects are specified, we propagate the unspecified objects using memory
            output_prob = processor.step(image, mask, objects=objects)
        else:
            # otherwise, we propagate the mask from memory
            output_prob = processor.step(image)

        # convert output probabilities to an object mask
        predicted = processor.output_prob_to_mask(output_prob)

        # visualize prediction
        visualization = Image.fromarray(predicted.cpu().numpy().astype(np.uint8))
        if palette is not None:
            # only indexed masks carry a palette; putpalette(None) would raise
            visualization.putpalette(palette)
        visualization.show()  # or use visualization.save(...) to save it somewhere
# Run the demo. NOTE: there is no `__main__` guard, so this also runs if the file is imported.
main()