Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Image Loading and Saving Functions #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 60 additions & 42 deletions lib/train/data/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,119 @@
from PIL import Image
import numpy as np

davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
[64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
[0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
[0, 64, 128], [128, 64, 128]]
# Define the Davis palette for indexed color images
davis_palette = np.repeat(np.expand_dims(np.arange(0, 256), 1), 3, 1).astype(np.uint8)
davis_palette[:22, :] = [
[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128],
[0, 128, 128], [128, 128, 128], [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128], [0, 64, 0], [128, 64, 0],
[0, 191, 0], [128, 191, 0], [0, 64, 128], [128, 64, 128]
]


def default_image_loader(path):
"""The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
but reverts to the opencv_loader if the former is not available."""
"""Attempts to load an image, first using jpeg4py, falling back to opencv if jpeg4py fails."""
if default_image_loader.use_jpeg4py is None:
# Try using jpeg4py
im = jpeg4py_loader(path)
if im is None:
img = jpeg4py_loader(path)
if img is None:
default_image_loader.use_jpeg4py = False
print('Using opencv_loader instead.')
else:
default_image_loader.use_jpeg4py = True
return im
return img

if default_image_loader.use_jpeg4py:
return jpeg4py_loader(path)
return opencv_loader(path)


default_image_loader.use_jpeg4py = None


def jpeg4py_loader(path):
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
"""Loads an image using jpeg4py."""
try:
return jpeg4py.JPEG(path).decode()
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print(f"ERROR: Could not read image '{path}' with jpeg4py.")
print(e)
return None


def opencv_loader(path):
""" Read image using opencv's imread function and returns it in rgb format"""
"""Loads an image using OpenCV and returns it in RGB format."""
try:
im = cv.imread(path, cv.IMREAD_COLOR)

# convert to rgb and return
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
img = cv.imread(path, cv.IMREAD_COLOR)
return cv.cvtColor(img, cv.COLOR_BGR2RGB) # Convert to RGB and return
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print(f"ERROR: Could not read image '{path}' with OpenCV.")
print(e)
return None


def jpeg4py_loader_w_failsafe(path):
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
"""Attempts to load an image using jpeg4py, falls back to OpenCV if it fails."""
try:
return jpeg4py.JPEG(path).decode()
except:
except Exception as e:
print(f"jpeg4py failed, trying OpenCV for '{path}'.")
try:
im = cv.imread(path, cv.IMREAD_COLOR)

# convert to rgb and return
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
img = cv.imread(path, cv.IMREAD_COLOR)
return cv.cvtColor(img, cv.COLOR_BGR2RGB) # Convert to RGB and return
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print(f"ERROR: Could not read image '{path}' with OpenCV either.")
print(e)
return None


def opencv_seg_loader(path):
""" Read segmentation annotation using opencv's imread function"""
"""Loads a segmentation image (annotation) using OpenCV."""
try:
return cv.imread(path)
return cv.imread(path, cv.IMREAD_UNCHANGED) # Use unchanged to read the raw format
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print(f"ERROR: Could not read segmentation image '{path}' with OpenCV.")
print(e)
return None


def imread_indexed(filename):
""" Load indexed image with given filename. Used to read segmentation annotations."""

im = Image.open(filename)

annotation = np.atleast_3d(im)[...,0]
return annotation
"""Reads an indexed image (typically used for segmentation annotations)."""
try:
im = Image.open(filename)
annotation = np.atleast_3d(im)[..., 0] # Take the first channel for segmentation
return annotation
except Exception as e:
print(f"ERROR: Could not read indexed image '{filename}'.")
print(e)
return None


def imwrite_indexed(filename, array, color_palette=None):
""" Save indexed image as png. Used to save segmentation annotation."""

"""Saves a 2D array (segmentation annotation) as an indexed PNG image."""
if color_palette is None:
color_palette = davis_palette

if np.atleast_3d(array).shape[2] != 1:
raise Exception("Saving indexed PNGs requires 2D array.")
raise ValueError("Saving indexed PNGs requires a 2D array.")

try:
im = Image.fromarray(array)
im.putpalette(color_palette.ravel()) # Apply the color palette
im.save(filename, format='PNG')
except Exception as e:
print(f"ERROR: Could not save indexed image '{filename}'.")
print(e)


# Example usage of the functions
if __name__ == "__main__":
img_path = 'path_to_image.jpg'
indexed_img_path = 'path_to_indexed_image.png'
seg_img = default_image_loader(img_path)
if seg_img is not None:
print("Image loaded successfully!")

im = Image.fromarray(array)
im.putpalette(color_palette.ravel())
im.save(filename, format='PNG')
# Save an indexed image
segmentation_array = np.zeros((100, 100), dtype=np.uint8) # Example 2D array for segmentation
imwrite_indexed(indexed_img_path, segmentation_array)