diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6447033..fc49494 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -75,7 +75,7 @@ jobs: shell: bash run: | python -m pip install --upgrade pip - pip install -U jupyter_packaging pytest simplejpeg pillow + pip install -U jupyter_packaging pytest simplejpeg pillow opencv-python-headless pip install . rm -rf ./jupyter_rfb ./build ./egg-info - name: Test with pytest diff --git a/jupyter_rfb/_jpg.py b/jupyter_rfb/_jpg.py index 7d80ac5..6874e66 100644 --- a/jupyter_rfb/_jpg.py +++ b/jupyter_rfb/_jpg.py @@ -99,12 +99,35 @@ def _encode(self, array, quality): return f.getvalue() +class OpenCVJpegEncoder(JpegEncoder): + """A JPEG encoder using the OpenCV library.""" + + def __init__(self): + import cv2 + + self.cv2 = cv2 + + def _encode(self, array, quality): + if len(array.shape) == 3 and array.shape[2] == 3: + # Convert RGB to BGR if needed (assume input is RGB) + array = self.cv2.cvtColor(array, self.cv2.COLOR_RGB2BGR) + + # Encode with the specified quality + encode_param = [self.cv2.IMWRITE_JPEG_QUALITY, quality] + success, encoded_image = self.cv2.imencode(".jpg", array, encode_param) + if not success: + raise RuntimeError("OpenCV failed to encode image") + + return encoded_image.tobytes() + + def select_encoder(): """Select an encoder.""" for cls in [ SimpleJpegEncoder, # simplejpeg is fast and lean PillowJpegEncoder, # pillow is commonly available + OpenCVJpegEncoder, # opencv is readily installed in conda environments ]: try: return cls() diff --git a/tests/test_jpg.py b/tests/test_jpg.py index e5b0711..cd549f7 100644 --- a/tests/test_jpg.py +++ b/tests/test_jpg.py @@ -1,6 +1,7 @@ """Test jpg module.""" import numpy as np +import pytest from pytest import raises from jupyter_rfb._jpg import ( @@ -8,6 +9,7 @@ select_encoder, SimpleJpegEncoder, PillowJpegEncoder, + OpenCVJpegEncoder, ) @@ -28,6 +30,7 @@ def test_array2jpg(): def test_simplejpeg_jpeg_encoder(): """Test the simplejpeg encoder.""" + pytest.importorskip("simplejpeg") encoder = SimpleJpegEncoder() _perform_checks(encoder) _perform_error_checks(encoder) @@ -35,11 +38,20 @@ def test_simplejpeg_jpeg_encoder(): def test_pillow_jpeg_encoder(): """Test the pillow encoder.""" + pytest.importorskip("PIL") encoder = PillowJpegEncoder() _perform_checks(encoder) _perform_error_checks(encoder) +def test_opencv_jpeg_encoder(): + """Test the opencv encoder.""" + pytest.importorskip("cv2") + encoder = OpenCVJpegEncoder() + _perform_checks(encoder) + _perform_error_checks(encoder) + + def _perform_checks(encoder): # RGB im = get_random_im(100, 100, 3) @@ -111,9 +123,11 @@ def test_select_encoder(): # Sabotage simple_init = SimpleJpegEncoder.__init__ pillow_init = PillowJpegEncoder.__init__ + cv2_init = OpenCVJpegEncoder.__init__ try: SimpleJpegEncoder.__init__ = lambda self: raise_importerror() PillowJpegEncoder.__init__ = lambda self: raise_importerror() + OpenCVJpegEncoder.__init__ = lambda self: raise_importerror() encoder = select_encoder() assert not isinstance(encoder, (SimpleJpegEncoder, PillowJpegEncoder)) @@ -125,3 +139,4 @@ def test_select_encoder(): finally: SimpleJpegEncoder.__init__ = simple_init PillowJpegEncoder.__init__ = pillow_init + OpenCVJpegEncoder.__init__ = cv2_init