diff --git a/icevision/visualize/draw_data.py b/icevision/visualize/draw_data.py index 97ffeb6b4..420818db7 100644 --- a/icevision/visualize/draw_data.py +++ b/icevision/visualize/draw_data.py @@ -368,7 +368,7 @@ def draw_bbox( # Calculate image dimensions dims = sorted(img.shape, reverse=True) color = as_rgb_tuple(color) - img = PIL.Image.fromarray(img) + img = PIL.Image.fromarray(img.squeeze()).convert("RGB") draw = PIL.ImageDraw.Draw(img) # corner thickness is linearly correlated with the smaller image dimension. @@ -482,7 +482,7 @@ def draw_mask( raise ValueError( f"`border_thickness` must be an odd number. You entered {border_thickness}" ) - img = PIL.Image.fromarray(img) + img = PIL.Image.fromarray(img.squeeze()) w, h = img.size mask_idxs = np.where(mask.data) @@ -520,7 +520,7 @@ def draw_keypoints( # calculate scaling for points and connections img_h, img_w, _ = img.shape img_area = img_h * img_w - img = PIL.Image.fromarray(img) + img = PIL.Image.fromarray(img.squeeze()) draw = PIL.ImageDraw.Draw(img) dynamic_size = int(0.01867599 * (img_area ** 0.4422045)) dynamic_size = max(dynamic_size, 1)