Skip to content

Commit

Permalink
Fix annotation of draw_segmentation_masks (pytorch#4527)
Browse files Browse the repository at this point in the history
* Add str param

* Update test to include str

* Fix mypy

* Remove a small bracket

* Test more robustly

* Update docstring and test:

* Apply suggestions from code review

Co-authored-by: Nicolas Hug <[email protected]>

* Update torchvision/utils.py

Small docstring fix

* Update torchvision/utils.py

* remove unnecessary renaming

Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
3 people authored and cyyever committed Nov 16, 2021
1 parent 43d5aa2 commit 6e021d1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
5 changes: 5 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def test_draw_invalid_boxes():
"colors",
[
None,
"blue",
"#FF00FF",
(1, 34, 122),
["red", "blue"],
["#FF00FF", (1, 34, 122)],
],
Expand Down Expand Up @@ -191,6 +194,8 @@ def test_draw_segmentation_masks(colors, alpha):

if colors is None:
colors = utils._generate_color_palette(num_masks)
elif isinstance(colors, str) or isinstance(colors, tuple):
colors = [colors]

# Make sure each mask draws with its own color
for mask, color in zip(masks, colors):
Expand Down
21 changes: 10 additions & 11 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def draw_bounding_boxes(
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes.
colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors
or a single color for all of the bounding boxes. The colors can be represented as `str` or
`Tuple[int, int, int]`.
colors (color or list of colors, optional): List containing the colors
of the boxes or single color for all boxes. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
Expand Down Expand Up @@ -231,7 +231,7 @@ def draw_segmentation_masks(
image: torch.Tensor,
masks: torch.Tensor,
alpha: float = 0.8,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

"""
Expand All @@ -243,10 +243,10 @@ def draw_segmentation_masks(
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
colors (list or None): List containing the colors of the masks. The colors can
be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
with one element. By default, random colors are generated for each mask.
colors (color or list of colors, optional): List containing the colors
of the masks or single color for all masks. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for each mask.
Returns:
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
Expand Down Expand Up @@ -289,8 +289,7 @@ def draw_segmentation_masks(
for color in colors:
if isinstance(color, str):
color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=out_dtype)
colors_.append(color)
colors_.append(torch.tensor(color, dtype=out_dtype))

img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this
Expand All @@ -301,6 +300,6 @@ def draw_segmentation_masks(
return out.to(out_dtype)


def _generate_color_palette(num_masks):
def _generate_color_palette(num_masks: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)]

0 comments on commit 6e021d1

Please sign in to comment.