From 97263b107767f8907b69cc5dcaa94804704c83ac Mon Sep 17 00:00:00 2001 From: "J. Sebastian Paez" Date: Fri, 1 Nov 2024 04:54:37 -0700 Subject: [PATCH] (refactor) moved format reader dispatch --- crates/sage-cli/src/runner.rs | 29 +++----------- crates/sage-cloudpath/src/util.rs | 65 +++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/crates/sage-cli/src/runner.rs b/crates/sage-cli/src/runner.rs index 4903ff9..7e16150 100644 --- a/crates/sage-cli/src/runner.rs +++ b/crates/sage-cli/src/runner.rs @@ -161,34 +161,17 @@ impl Runner { min_deisotope_mz.unwrap_or(0.0), ); - let bruker_extensions = [".d", ".tdf", ".tdf_bin", "ms2", "raw"]; let spectra = chunk .par_iter() .enumerate() .flat_map(|(idx, path)| { let file_id = chunk_idx * batch_size + idx; - - let path_lower = path.to_lowercase(); - let res = if path_lower.ends_with(".mgf.gz") || path_lower.ends_with(".mgf") { - sage_cloudpath::util::read_mgf(path, file_id) - } else if bruker_extensions.iter().any(|ext| { - if path_lower.ends_with(std::path::MAIN_SEPARATOR) { - path_lower - .strip_suffix(std::path::MAIN_SEPARATOR) - .unwrap() - .ends_with(ext) - } else { - path_lower.ends_with(ext) - } - }) { - sage_cloudpath::util::read_tdf( - path, - file_id, - self.parameters.bruker_spectrum_processor, - ) - } else { - sage_cloudpath::util::read_mzml(path, file_id, sn) - }; + let res = sage_cloudpath::util::read_spectra( + path, + file_id, + sn, + self.parameters.bruker_spectrum_processor, + ); match res { Ok(s) => { diff --git a/crates/sage-cloudpath/src/util.rs b/crates/sage-cloudpath/src/util.rs index c3f18b8..3bbc500 100644 --- a/crates/sage-cloudpath/src/util.rs +++ b/crates/sage-cloudpath/src/util.rs @@ -3,6 +3,55 @@ use sage_core::spectrum::RawSpectrum; use serde::Serialize; use tokio::io::AsyncReadExt; +#[derive(Debug, PartialEq, Eq)] +enum FileFormat { + MzML, + MGF, + TDF, + Unidentified, +} + +const BRUKER_EXTENSIONS: [&str; 5] = [".d", ".tdf", ".tdf_bin", "ms2", "raw"]; + +fn is_bruker(path: &str) -> bool { + BRUKER_EXTENSIONS.iter().any(|ext| { + if path.ends_with(std::path::MAIN_SEPARATOR) { + path.strip_suffix(std::path::MAIN_SEPARATOR) + .unwrap() + .ends_with(ext) + } else { + path.ends_with(ext) + } + }) +} + +fn identify_format(s: &str) -> FileFormat { + let path_lower = s.to_lowercase(); + if path_lower.ends_with(".mgf.gz") || path_lower.ends_with(".mgf") { + FileFormat::MGF + } else if is_bruker(&path_lower) { + FileFormat::TDF + } else if path_lower.ends_with(".mzml.gz") || path_lower.ends_with(".mzml") { + FileFormat::MzML + } else { + FileFormat::Unidentified + } +} + +pub fn read_spectra>( + path: S, + file_id: usize, + sn: Option, + bruker_processor: BrukerSpectrumProcessor, +) -> Result, Error> { + match identify_format(path.as_ref()) { + FileFormat::MzML => read_mzml(path, file_id, sn), + FileFormat::MGF => read_mgf(path, file_id), + FileFormat::TDF => read_tdf(path, file_id, bruker_processor), + FileFormat::Unidentified => panic!("Unable to get type for '{}'", path.as_ref()), // read_mzml(path, file_id, sn), + } +} + pub fn read_mzml>( s: S, file_id: usize, @@ -91,3 +140,19 @@ where Ok(()) }) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_identify_format() { + assert_eq!(identify_format("foo.mzml"), FileFormat::MzML); + assert_eq!(identify_format("foo.mzML"), FileFormat::MzML); + assert_eq!(identify_format("foo.mgf"), FileFormat::MGF); + assert_eq!(identify_format("foo.mgf.gz"), FileFormat::MGF); + assert_eq!(identify_format("foo.tdf"), FileFormat::TDF); + assert_eq!(identify_format("./tomato/foo.d"), FileFormat::TDF); + assert_eq!(identify_format("./tomato/foo.d/"), FileFormat::TDF); + } +}