From b6032a41c83e74c911b3f5549305a5d6a30bb660 Mon Sep 17 00:00:00 2001 From: Leigh McCulloch <351529+leighmcculloch@users.noreply.github.com> Date: Tue, 16 Jul 2024 02:46:45 +1000 Subject: [PATCH] Skip whitespace in base64 encoded inputs given to the CLI (#377) * Skip whitespace in base64 encoded inputs given to the CLI * guess --- src/cli/decode.rs | 20 +++++++---- src/cli/guess.rs | 45 +++++++++++++++--------- src/cli/mod.rs | 1 + src/cli/skip_whitespace.rs | 70 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 24 deletions(-) create mode 100644 src/cli/skip_whitespace.rs diff --git a/src/cli/decode.rs b/src/cli/decode.rs index 8eabf0ff..d0b4d741 100644 --- a/src/cli/decode.rs +++ b/src/cli/decode.rs @@ -9,7 +9,7 @@ use std::{ use clap::{Args, ValueEnum}; use serde::Serialize; -use crate::cli::Channel; +use crate::cli::{skip_whitespace::SkipWhitespace, Channel}; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -82,28 +82,34 @@ macro_rules! run_x { Error::UnknownType(self.r#type.clone(), &crate::$m::TypeVariant::VARIANTS_STR) })?; for f in &mut files { - let mut f = crate::$m::Limited::new(f, crate::$m::Limits::none()); match self.input { InputFormat::Single => { - let t = crate::$m::Type::read_xdr_to_end(r#type, &mut f)?; + let mut l = crate::$m::Limited::new(f, crate::$m::Limits::none()); + let t = crate::$m::Type::read_xdr_to_end(r#type, &mut l)?; self.out(&t)?; } InputFormat::SingleBase64 => { - let t = crate::$m::Type::read_xdr_base64_to_end(r#type, &mut f)?; + let sw = SkipWhitespace::new(f); + let mut l = crate::$m::Limited::new(sw, crate::$m::Limits::none()); + let t = crate::$m::Type::read_xdr_base64_to_end(r#type, &mut l)?; self.out(&t)?; } InputFormat::Stream => { - for t in crate::$m::Type::read_xdr_iter(r#type, &mut f) { + let mut l = crate::$m::Limited::new(f, crate::$m::Limits::none()); + for t in crate::$m::Type::read_xdr_iter(r#type, &mut l) { self.out(&t?)?; } } InputFormat::StreamBase64 => { - for t in crate::$m::Type::read_xdr_base64_iter(r#type, &mut f) { + let sw = SkipWhitespace::new(f); + let mut l = crate::$m::Limited::new(sw, crate::$m::Limits::none()); + for t in crate::$m::Type::read_xdr_base64_iter(r#type, &mut l) { self.out(&t?)?; } } InputFormat::StreamFramed => { - for t in crate::$m::Type::read_xdr_framed_iter(r#type, &mut f) { + let mut l = crate::$m::Limited::new(f, crate::$m::Limits::none()); + for t in crate::$m::Type::read_xdr_framed_iter(r#type, &mut l) { self.out(&t?)?; } } diff --git a/src/cli/guess.rs b/src/cli/guess.rs index 5212fc04..4623910f 100644 --- a/src/cli/guess.rs +++ b/src/cli/guess.rs @@ -7,7 +7,7 @@ use std::{ use clap::{Args, ValueEnum}; -use crate::cli::Channel; +use crate::cli::{skip_whitespace::SkipWhitespace, Channel}; #[derive(thiserror::Error, Debug)] #[allow(clippy::enum_variant_names)] @@ -69,21 +69,29 @@ impl Default for OutputFormat { macro_rules! run_x { ($f:ident, $m:ident) => { fn $f(&self) -> Result<(), Error> { - let mut f = - crate::$m::Limited::new(ResetRead::new(self.file()?), crate::$m::Limits::none()); + let mut rr = ResetRead::new(self.file()?); 'variants: for v in crate::$m::TypeVariant::VARIANTS { - f.inner.reset(); + rr.reset(); let count: usize = match self.input { - InputFormat::Single => crate::$m::Type::read_xdr_to_end(v, &mut f) - .ok() - .map(|_| 1) - .unwrap_or_default(), - InputFormat::SingleBase64 => crate::$m::Type::read_xdr_base64_to_end(v, &mut f) - .ok() - .map(|_| 1) - .unwrap_or_default(), + InputFormat::Single => { + let mut l = crate::$m::Limited::new(&mut rr, crate::$m::Limits::none()); + crate::$m::Type::read_xdr_to_end(v, &mut l) + .ok() + .map(|_| 1) + .unwrap_or_default() + } + InputFormat::SingleBase64 => { + let sw = SkipWhitespace::new(&mut rr); + let mut l = crate::$m::Limited::new(sw, crate::$m::Limits::none()); + crate::$m::Type::read_xdr_base64_to_end(v, &mut l) + .ok() + .map(|_| 1) + .unwrap_or_default() + } InputFormat::Stream => { - let iter = crate::$m::Type::read_xdr_iter(v, &mut f).take(self.certainty); + let mut l = crate::$m::Limited::new(&mut rr, crate::$m::Limits::none()); + let iter = crate::$m::Type::read_xdr_iter(v, &mut l); + let iter = iter.take(self.certainty); let mut count = 0; for v in iter { match v { @@ -94,8 +102,10 @@ macro_rules! run_x { count } InputFormat::StreamBase64 => { - let iter = - crate::$m::Type::read_xdr_base64_iter(v, &mut f).take(self.certainty); + let sw = SkipWhitespace::new(&mut rr); + let mut l = crate::$m::Limited::new(sw, crate::$m::Limits::none()); + let iter = crate::$m::Type::read_xdr_base64_iter(v, &mut l); + let iter = iter.take(self.certainty); let mut count = 0; for v in iter { match v { @@ -106,8 +116,9 @@ macro_rules! run_x { count } InputFormat::StreamFramed => { - let iter = - crate::$m::Type::read_xdr_framed_iter(v, &mut f).take(self.certainty); + let mut l = crate::$m::Limited::new(&mut rr, crate::$m::Limits::none()); + let iter = crate::$m::Type::read_xdr_framed_iter(v, &mut l); + let iter = iter.take(self.certainty); let mut count = 0; for v in iter { match v { diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 68dd0fb6..6128be04 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,6 +1,7 @@ pub mod decode; pub mod encode; pub mod guess; +mod skip_whitespace; pub mod types; mod version; diff --git a/src/cli/skip_whitespace.rs b/src/cli/skip_whitespace.rs new file mode 100644 index 00000000..bb2cadb1 --- /dev/null +++ b/src/cli/skip_whitespace.rs @@ -0,0 +1,70 @@ +use std::io::Read; + +/// Forwards read operations to the wrapped object, skipping over any +/// whitespace. +pub struct SkipWhitespace { + pub inner: R, +} + +impl SkipWhitespace { + pub fn new(inner: R) -> Self { + SkipWhitespace { inner } + } +} + +impl Read for SkipWhitespace { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let n = self.inner.read(buf)?; + + let mut written = 0; + for read in 0..n { + if !buf[read].is_ascii_whitespace() { + buf[written] = buf[read]; + written += 1; + } + } + + Ok(written) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test() { + struct Test { + input: &'static [u8], + output: &'static [u8], + } + let tests = [ + Test { + input: b"", + output: b"", + }, + Test { + input: b" \n\t\r", + output: b"", + }, + Test { + input: b"a c", + output: b"ac", + }, + Test { + input: b"ab cd", + output: b"abcd", + }, + Test { + input: b" ab \n cd ", + output: b"abcd", + }, + ]; + for (i, t) in tests.iter().enumerate() { + let mut skip = SkipWhitespace::new(t.input); + let mut output = Vec::new(); + skip.read_to_end(&mut output).unwrap(); + assert_eq!(output, t.output, "#{i}"); + } + } +}