diff --git a/src/decode.rs b/src/decode.rs index 27e70d4..56ca34c 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,10 +1,11 @@ use std::borrow::Cow; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use jpegxl_rs::decode::{Data, Metadata, Pixels}; -use jpegxl_rs::decoder_builder; use jpegxl_rs::parallel::threads_runner::ThreadsRunner; +use jpegxl_rs::{decoder_builder, DecodeError}; // it works even if the item is not documented: @@ -86,7 +87,7 @@ impl Decoder { &self, _py: Python, data: &[u8], - ) -> (bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>) { + ) -> PyResult<(bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>)> { _py.allow_threads(|| self.call_inner(data)) } @@ -96,7 +97,7 @@ impl Decoder { } impl Decoder { - fn call_inner(&self, data: &[u8]) -> (bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>) { + fn call_inner(&self, data: &[u8]) -> PyResult<(bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>)> { let parallel_runner = ThreadsRunner::new( None, if self.num_threads < 0 { @@ -104,13 +105,14 @@ impl Decoder { } else { Some(self.num_threads as usize) }, - ).unwrap(); + ) + .ok_or_else(|| PyRuntimeError::new_err("Could not create JxlThreadsRunner"))?; let decoder = decoder_builder() .icc_profile(true) .parallel_runner(¶llel_runner) .build() - .unwrap(); - let (info, img) = decoder.reconstruct(&data).unwrap(); + .map_err(to_pyjxlerror)?; + let (info, img) = decoder.reconstruct(&data).map_err(to_pyjxlerror)?; let (jpeg, img) = match img { Data::Jpeg(x) => (true, x), Data::Pixels(x) => (false, convert_pixels(x)), @@ -119,11 +121,15 @@ impl Decoder { Some(x) => x.to_vec(), None => Vec::new(), }; - ( + Ok(( jpeg, ImageInfo::from(info), Cow::Owned(img), Cow::Owned(icc_profile), - ) + )) } } + +fn to_pyjxlerror(e: DecodeError) -> PyErr { + PyRuntimeError::new_err(e.to_string()) +} diff --git a/src/encode.rs b/src/encode.rs index 613feae..70fa771 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -1,10 +1,11 @@ use std::borrow::Cow; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use jpegxl_rs::encode::{ColorEncoding, EncoderFrame, EncoderResult, EncoderSpeed, Metadata}; -use jpegxl_rs::encoder_builder; use jpegxl_rs::parallel::threads_runner::ThreadsRunner; +use jpegxl_rs::{encoder_builder, EncodeError}; #[pyclass(module = "pillow_jxl")] pub struct Encoder { @@ -32,18 +33,26 @@ impl Encoder { use_container: bool, use_original_profile: bool, num_threads: isize, - ) -> Self { + ) -> PyResult { let (num_channels, has_alpha) = match mode { "RGBA" => (4, true), "RGB" => (3, false), "LA" => (2, true), "L" => (1, false), - _ => panic!("Only RGB, RGBA, L, LA are supported."), + _ => { + return Err(PyValueError::new_err( + "Only RGB, RGBA, L, LA are supported.", + )) + } }; let decoding_speed = match decoding_speed { 0..=4 => decoding_speed, - _ => panic!("Decoding speed must be between 0 and 4"), + _ => { + return Err(PyValueError::new_err( + "Decoding speed must be between 0 and 4", + )) + } }; let use_original_profile = match lossless { @@ -51,7 +60,7 @@ impl Encoder { false => use_original_profile, }; - Self { + Ok(Self { num_channels, has_alpha, lossless, @@ -61,7 +70,7 @@ impl Encoder { use_container, use_original_profile, num_threads, - } + }) } #[pyo3(signature = (data, width, height, jpeg_encode, exif=None, jumb=None, xmp=None))] @@ -75,7 +84,7 @@ impl Encoder { exif: Option<&[u8]>, jumb: Option<&[u8]>, xmp: Option<&[u8]>, - ) -> Cow<'_, [u8]> { + ) -> PyResult> { py.allow_threads(|| self.call_inner(data, width, height, jpeg_encode, exif, jumb, xmp)) } @@ -97,7 +106,7 @@ impl Encoder { exif: Option<&[u8]>, jumb: Option<&[u8]>, xmp: Option<&[u8]>, - ) -> Cow<'_, [u8]> { + ) -> PyResult> { let parallel_runner = ThreadsRunner::new( None, if self.num_threads < 0 { @@ -105,7 +114,8 @@ impl Encoder { } else { Some(self.num_threads as usize) }, - ).unwrap(); + ) + .ok_or_else(|| PyRuntimeError::new_err("Could not create JxlThreadsRunner"))?; let mut encoder = encoder_builder() .parallel_runner(¶llel_runner) .jpeg_quality(self.quality) @@ -114,12 +124,12 @@ impl Encoder { .use_container(self.use_container) .decoding_speed(self.decoding_speed) .build() - .unwrap(); + .map_err(to_pyjxlerror)?; encoder.uses_original_profile = self.use_original_profile; encoder.color_encoding = match self.num_channels { 1 | 2 => ColorEncoding::SrgbLuma, 3 | 4 => ColorEncoding::Srgb, - _ => panic!("Invalid num channels"), + _ => return Err(PyValueError::new_err("Invalid num channels")), }; encoder.speed = match self.effort { 1 => EncoderSpeed::Lightning, @@ -131,30 +141,36 @@ impl Encoder { 7 => EncoderSpeed::Squirrel, 8 => EncoderSpeed::Kitten, 9 => EncoderSpeed::Tortoise, - _ => panic!("Invalid effort"), + _ => return Err(PyValueError::new_err("Invalid effort")), }; let buffer: EncoderResult = match jpeg_encode { - true => encoder.encode_jpeg(&data).unwrap(), + true => encoder.encode_jpeg(&data).map_err(to_pyjxlerror)?, false => { let frame = EncoderFrame::new(data).num_channels(self.num_channels); if let Some(exif_data) = exif { encoder .add_metadata(&Metadata::Exif(exif_data), true) - .unwrap(); + .map_err(to_pyjxlerror)? } if let Some(xmp_data) = xmp { encoder .add_metadata(&Metadata::Xmp(xmp_data), true) - .unwrap(); + .map_err(to_pyjxlerror)? } if let Some(jumb_data) = jumb { encoder .add_metadata(&Metadata::Jumb(jumb_data), true) - .unwrap(); + .map_err(to_pyjxlerror)? } - encoder.encode_frame(&frame, width, height).unwrap() + encoder + .encode_frame(&frame, width, height) + .map_err(to_pyjxlerror)? } }; - Cow::Owned(buffer.data) + Ok(Cow::Owned(buffer.data)) } } + +fn to_pyjxlerror(e: EncodeError) -> PyErr { + PyRuntimeError::new_err(e.to_string()) +} diff --git a/src/lib.rs b/src/lib.rs index 4072806..963a9e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,16 @@ -use pyo3::prelude::*; +use pyo3::{create_exception, exceptions::PyRuntimeError, prelude::*}; // it works even if the item is not documented: mod decode; mod encode; +create_exception!(my_module, JxlException, PyRuntimeError, "Jxl Error"); + #[pymodule] #[pyo3(name = "pillow_jxl")] fn pillow_jxl(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add("JxlException", m.py().get_type_bound::())?; Ok(()) }