diff --git a/src/decoder.rs b/src/decoder.rs index cda7fc7..afafec4 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -1,7 +1,7 @@ // Copyright (c) Team CharLS. // SPDX-License-Identifier: BSD-3-Clause -use std::io::{Read, self}; +use std::io::{Read}; #[warn(unused_variables)] @@ -19,7 +19,7 @@ pub struct Decoder { impl Decoder { - pub fn new(mut r: R) -> Decoder { + pub fn new(r: R) -> Decoder { let width = 0; let height = 0; let bits_per_sample = 0; diff --git a/src/jpeg_marker_code.rs b/src/jpeg_marker_code.rs index 08218f5..6e4602d 100644 --- a/src/jpeg_marker_code.rs +++ b/src/jpeg_marker_code.rs @@ -1,6 +1,7 @@ // Copyright (c) Team CharLS. // SPDX-License-Identifier: BSD-3-Clause +use std::convert::TryFrom; #[derive(Debug, Eq, PartialEq)] pub enum JpegMarkerCode { @@ -46,3 +47,16 @@ pub enum JpegMarkerCode { ApplicationData15 = 0xEF, // APP15: Application data 15. Comment = 0xFE // COM: Comment block. } + +impl TryFrom for JpegMarkerCode { + type Error = (); + + fn try_from(v: u8) -> Result { + match v { + x if x == JpegMarkerCode::StartOfImage as u8 => Ok(JpegMarkerCode::StartOfImage), + x if x == JpegMarkerCode::EndOfImage as u8 => Ok(JpegMarkerCode::EndOfImage), + x if x == JpegMarkerCode::StartOfScan as u8 => Ok(JpegMarkerCode::StartOfScan), + _ => Err(()), + } + } +} diff --git a/src/jpeg_stream_reader.rs b/src/jpeg_stream_reader.rs index 5327ae7..84eeeeb 100644 --- a/src/jpeg_stream_reader.rs +++ b/src/jpeg_stream_reader.rs @@ -3,7 +3,7 @@ //mod jpeg_marker_code; -use std::io::{Read, self}; +use std::io::Read; use crate::jpeg_marker_code::JpegMarkerCode; use crate::decoding_error::DecodingError; @@ -40,7 +40,7 @@ pub struct JpegStreamReader { impl JpegStreamReader { - pub fn new(mut r: R) -> JpegStreamReader { + pub fn new(r: R) -> JpegStreamReader { let width = 0; let height = 0; let bits_per_sample = 0; @@ -58,13 +58,28 @@ impl JpegStreamReader { } } - pub fn read_next_marker_code(&mut self) -> JpegMarkerCode { - JpegMarkerCode::StartOfImage + pub fn read_next_marker_code(&mut self) -> Result { + let mut value = self.read_u8()?; + if value != 255 { + return Err(DecodingError::StartOfImageMarkerNotFound); + } + + // Read all preceding 0xFF fill values until a non 0xFF value has been found. (see ISO/IEC 10918-1, B.1.1.2) + while value == 255 { + value = self.read_u8()?; + } + + let r = JpegMarkerCode::try_from(value); + if r.is_err() { + return Err(DecodingError::StartOfImageMarkerNotFound); + } + + return Ok(r.unwrap()) } pub fn read_header(&mut self) -> Result<(), DecodingError> { if self.state == ReaderState::BeforeStartOfImage { - if self.read_next_marker_code() != JpegMarkerCode::StartOfImage { + if self.read_next_marker_code()? != JpegMarkerCode::StartOfImage { return Err(DecodingError::StartOfImageMarkerNotFound); } @@ -73,5 +88,29 @@ impl JpegStreamReader { Ok(()) } + + fn read_u8(&mut self) -> Result { + let mut buf = [0; 1]; + let result = self.reader.read_exact(&mut buf); + if result.is_err() { + return Err(DecodingError::UnknownError); + } + + Ok(buf[0]) + } } +#[cfg(test)] +mod tests { + use std::io::Write; + use super::*; + + #[test] + fn read_header_from_too_small_input_buffer_fails() { + let mut buffer = Vec::new(); + buffer.write_all(&[1]).unwrap(); + + let mut reader = JpegStreamReader::new(buffer.as_slice()); + assert!(reader.read_header().is_err()); + } +}