Skip to content

Commit

Permalink
Check for duplicate stdin usage on read instead of arg parsing (#10)
Browse files Browse the repository at this point in the history
This change relaxes the check for duplicate usage of `stdin` on arg
declarations (error at runtime if any two args use `MaybeStdin` or
`FileOrStdin`) and instead only check for duplicated `stdin` usage
when these args are accessed.

This allows usages of args that may be mutually exclusive (E.g. under
different subcommands) or use the `global=true` clap option (as reported
in #9)

If a tool happens to accept multiple args that can be Stdin, the CLI user
will only see an error if they actually try to use `stdin` for values
twice.
  • Loading branch information
thepacketgeek authored Jul 6, 2024
1 parent fb6d12f commit 409d461
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 44 deletions.
23 changes: 13 additions & 10 deletions src/file_or_stdin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,21 @@ use super::{Source, StdinError};
/// ```
#[derive(Debug, Clone)]
pub struct FileOrStdin<T = String> {
pub source: Source,
source: Source,
_type: PhantomData<T>,
}

impl<T> FileOrStdin<T> {
/// Was this value read from stdin
pub fn is_stdin(&self) -> bool {
matches!(self.source, Source::Stdin)
}

/// Was this value read from a file (path passed in from argument values)
pub fn is_file(&self) -> bool {
!self.is_stdin()
}

/// Read the entire contents from the input source, returning T::from_str
pub fn contents(self) -> Result<T, StdinError>
where
Expand Down Expand Up @@ -77,15 +87,8 @@ impl<T> FileOrStdin<T> {
/// # Ok(())
/// # }
/// ```
pub fn into_reader(&self) -> Result<impl std::io::Read, StdinError> {
let input: Box<dyn std::io::Read + 'static> = match &self.source {
Source::Stdin => Box::new(std::io::stdin()),
Source::Arg(filepath) => {
let f = std::fs::File::open(filepath)?;
Box::new(f)
}
};
Ok(input)
pub fn into_reader(self) -> Result<impl std::io::Read, StdinError> {
self.source.into_reader()
}

#[cfg(feature = "tokio")]
Expand Down
51 changes: 40 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![doc = include_str!("../README.md")]

use std::io;
use std::io::{self, Read};
use std::str::FromStr;
use std::sync::atomic::AtomicBool;

Expand All @@ -9,11 +9,11 @@ pub use maybe_stdin::MaybeStdin;
mod file_or_stdin;
pub use file_or_stdin::FileOrStdin;

static STDIN_HAS_BEEN_USED: AtomicBool = AtomicBool::new(false);
static STDIN_HAS_BEEN_READ: AtomicBool = AtomicBool::new(false);

#[derive(Debug, thiserror::Error)]
pub enum StdinError {
#[error("stdin argument used more than once")]
#[error("stdin read from more than once")]
StdInRepeatedUse,
#[error(transparent)]
StdIn(#[from] io::Error),
Expand All @@ -23,23 +23,52 @@ pub enum StdinError {

/// Source of the value contents will be either from `stdin` or a CLI arg provided value
#[derive(Clone)]
pub enum Source {
pub(crate) enum Source {
Stdin,
Arg(String),
}

impl Source {
pub(crate) fn into_reader(self) -> Result<impl std::io::Read, StdinError> {
let input: Box<dyn std::io::Read + 'static> = match self {
Source::Stdin => {
if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst);
Box::new(std::io::stdin())
}
Source::Arg(filepath) => {
let f = std::fs::File::open(filepath)?;
Box::new(f)
}
};
Ok(input)
}

pub(crate) fn get_value(self) -> Result<String, StdinError> {
match self {
Source::Stdin => {
if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst);
let stdin = io::stdin();
let mut input = String::new();
stdin.lock().read_to_string(&mut input)?;
Ok(input)
}
Source::Arg(value) => Ok(value),
}
}
}

impl FromStr for Source {
type Err = StdinError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => {
if STDIN_HAS_BEEN_USED.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_USED.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(Self::Stdin)
}
"-" => Ok(Self::Stdin),
arg => Ok(Self::Arg(arg.to_owned())),
}
}
Expand Down
19 changes: 3 additions & 16 deletions src/maybe_stdin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::io::{self, Read};
use std::str::FromStr;

use super::{Source, StdinError};
Expand Down Expand Up @@ -27,8 +26,6 @@ use super::{Source, StdinError};
/// ```
#[derive(Clone)]
pub struct MaybeStdin<T> {
/// Source of the contents
pub source: Source,
inner: T,
}

Expand All @@ -41,19 +38,9 @@ where

fn from_str(s: &str) -> Result<Self, Self::Err> {
let source = Source::from_str(s)?;
match &source {
Source::Stdin => {
let stdin = io::stdin();
let mut input = String::new();
stdin.lock().read_to_string(&mut input)?;
Ok(T::from_str(input.trim_end())
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { source, inner: val })?)
}
Source::Arg(value) => Ok(T::from_str(value)
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { source, inner: val })?),
}
T::from_str(source.get_value()?.trim())
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { inner: val })
}
}

Expand Down
7 changes: 4 additions & 3 deletions tests/fixtures/file_or_stdin_positional_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ struct Args {
}

#[cfg(feature = "test_bin")]
fn main() {
fn main() -> Result<(), String> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {:?}",
args.first.contents().unwrap(),
args.first.contents().map_err(|e| format!("{e}"))?,
args.second
);
Ok(())
}

#[cfg(feature = "test_bin_tokio")]
Expand All @@ -26,7 +27,7 @@ async fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {:?}",
args.first.contents_async().await.unwrap(),
args.first.contents_async().await?,
args.second
);
}
10 changes: 6 additions & 4 deletions tests/fixtures/file_or_stdin_twice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ struct Args {
}

#[cfg(feature = "test_bin")]
fn main() {
fn main() -> Result<(), String> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {}",
args.first.contents().unwrap(),
"FIRST: {}; SECOND: {:?}",
args.first.contents().map_err(|e| format!("{e}"))?,
args.second
);

Ok(())
}

#[cfg(feature = "test_bin_tokio")]
Expand All @@ -24,7 +26,7 @@ async fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {}",
args.first.contents_async().unwrap(),
args.first.contents_async()?,
args.second
);
}

0 comments on commit 409d461

Please sign in to comment.