diff --git a/src/decode.rs b/src/decode.rs index 56ca34c..c56930d 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -24,9 +24,13 @@ struct ImageInfo { } impl ImageInfo { - fn from(item: Metadata) -> ImageInfo { + fn from(item: Metadata, pixel: &Data) -> ImageInfo { + let pixel_type = match &pixel { + Data::Pixels(pixels) => Some(pixels), + Data::Jpeg(_) => None, + }; ImageInfo { - mode: Self::mode(item.num_color_channels, item.has_alpha_channel), + mode: Self::mode(item.num_color_channels, item.has_alpha_channel, pixel_type), width: item.width, height: item.height, num_channels: item.num_color_channels, @@ -34,39 +38,26 @@ impl ImageInfo { } } - fn mode(num_channels: u32, has_alpha_channel: bool) -> String { - match (num_channels, has_alpha_channel) { + fn mode(num_channels: u32, has_alpha_channel: bool, pixel_type: Option<&Pixels>) -> String { + let mode = match (num_channels, has_alpha_channel) { (1, false) => "L".to_string(), (1, true) => "LA".to_string(), (3, false) => "RGB".to_string(), (3, true) => "RGBA".to_string(), _ => panic!("Unsupported number of channels"), - } - } -} - -pub fn convert_pixels(pixels: Pixels) -> Vec { - let mut result = Vec::new(); - match pixels { - Pixels::Uint8(pixels) => { - for pixel in pixels { - result.push(pixel); - } - } - Pixels::Uint16(pixels) => { - for pixel in pixels { - result.push((pixel >> 8) as u8); - result.push(pixel as u8); + }; + if let Some(Pixels::Uint16(_)) = pixel_type { + if mode == "L" { + return "I;16".to_string(); } } - Pixels::Float(pixels) => { - for pixel in pixels { - result.push((pixel * 255.0) as u8); + if let Some(Pixels::Float(_)) = pixel_type { + if mode == "L" { + return "F".to_string(); } } - Pixels::Float16(_) => panic!("Float16 is not supported yet"), + mode } - result } #[pyclass(module = "pillow_jxl")] @@ -96,6 +87,67 @@ impl Decoder { } } +impl Decoder { + fn pixels_to_bytes_8bit(&self, pixels: Pixels) -> Vec { + // Convert pixels to bytes with 8-bit casting + let mut result = Vec::new(); + match pixels { + Pixels::Uint8(pixels) => { + return pixels; + } + Pixels::Uint16(pixels) => { + for pixel in pixels { + result.push((pixel >> 8) as u8); + } + } + Pixels::Float(pixels) => { + for pixel in pixels { + result.push((pixel * 255.0) as u8); + } + } + Pixels::Float16(_) => panic!("Float16 is not supported yet"), + } + result + } + + fn pixels_to_bytes(&self, pixels: Pixels) -> Vec { + // Convert pixels to bytes without casting + let mut result = Vec::new(); + match pixels { + Pixels::Uint8(pixels) => { + return pixels; + } + Pixels::Uint16(pixels) => { + for pixel in pixels { + let pix_bytes = pixel.to_ne_bytes(); + for byte in pix_bytes.iter() { + result.push(*byte); + } + } + } + Pixels::Float(pixels) => { + for pixel in pixels { + let pix_bytes = pixel.to_ne_bytes(); + for byte in pix_bytes.iter() { + result.push(*byte); + } + } + } + Pixels::Float16(_) => panic!("Float16 is not supported yet"), + } + result + } + + fn convert_pil_pixels(&self, pixels: Pixels, num_channels: u32) -> Vec { + let result = match num_channels { + 1 => self.pixels_to_bytes(pixels), + 3 => self.pixels_to_bytes_8bit(pixels), + _ => panic!("Unsupported number of channels"), + }; + result + } +} + impl Decoder { fn call_inner(&self, data: &[u8]) -> PyResult<(bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>)> { let parallel_runner = ThreadsRunner::new( @@ -113,20 +165,16 @@ impl Decoder { .build() .map_err(to_pyjxlerror)?; let (info, img) = decoder.reconstruct(&data).map_err(to_pyjxlerror)?; - let (jpeg, img) = match img { - Data::Jpeg(x) => (true, x), - Data::Pixels(x) => (false, convert_pixels(x)), - }; let icc_profile: Vec = match &info.icc_profile { Some(x) => x.to_vec(), None => Vec::new(), }; - Ok(( - jpeg, - ImageInfo::from(info), - Cow::Owned(img), - Cow::Owned(icc_profile), - )) + let img_info = ImageInfo::from(info, &img); + let (jpeg, img) = match img { + Data::Jpeg(x) => (true, x), + Data::Pixels(x) => (false, self.convert_pil_pixels(x, img_info.num_channels)), + }; + Ok((jpeg, img_info, Cow::Owned(img), Cow::Owned(icc_profile))) } } diff --git a/test/images/sample_grey.jxl b/test/images/sample_grey.jxl new file mode 100644 index 0000000..c9d8505 Binary files /dev/null and b/test/images/sample_grey.jxl differ diff --git a/test/images/sample_grey.png b/test/images/sample_grey.png new file mode 100644 index 0000000..ebf7997 Binary files /dev/null and b/test/images/sample_grey.png differ diff --git a/test/test_plugin.py b/test/test_plugin.py index a94db8a..bf9345b 100644 --- a/test/test_plugin.py +++ b/test/test_plugin.py @@ -7,12 +7,25 @@ def test_decode(): - img = Image.open("test/images/sample.jxl") + img_jxl = Image.open("test/images/sample.jxl") + img_png = Image.open("test/images/sample.png") - assert img.size == (40, 50) - assert img.mode == "RGBA" - assert not img.is_animated - assert img.n_frames == 1 + assert img_jxl.size == img_png.size + assert img_jxl.mode == img_png.mode == "RGBA" + assert not img_jxl.is_animated + assert img_jxl.n_frames == 1 + assert list(img_jxl.getdata()) == list(img_png.getdata()) + + +def test_decode_I16(): + img_jxl = Image.open("test/images/sample_grey.jxl") + img_png = Image.open("test/images/sample_grey.png") + + assert img_jxl.size == img_png.size + assert img_jxl.mode == img_png.mode == "I;16" + assert not img_jxl.is_animated + assert img_jxl.n_frames == 1 + assert list(img_jxl.getdata()) == list(img_png.getdata()) @pytest.mark.parametrize("image", ["test/images/sample.png", "test/images/sample.jpg"])