diff --git a/README.md b/README.md index 6b05163..331785c 100644 --- a/README.md +++ b/README.md @@ -204,11 +204,20 @@ fn main() { ```rust use std::collections::HashMap; +use std::collections::HashMap; +use reqwest::header::HeaderMap; fn main() { - let mut reader = oneio::get_remote_reader( + let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])) + .try_into().expect("invalid headers"); + + let client = reqwest::blocking::Client::builder() + .default_headers(headers) + .danger_accept_invalid_certs(true) + .build().unwrap(); + let mut reader = oneio::get_http_reader( "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN", - HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())]) + Some(client), ).unwrap(); let mut text = "".to_string(); reader.read_to_string(&mut text).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 3c60ce1..d3c1e46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,9 +83,14 @@ //! Read remote content with custom headers //! ```no_run //! use std::collections::HashMap; -//! let mut reader = oneio::get_remote_reader( +//! use reqwest::header::HeaderMap; +//! let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers"); +//! let client = reqwest::blocking::Client::builder() +//! .default_headers(headers) +//! .build().unwrap(); +//! let mut reader = oneio::get_http_reader( //! "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN", -//! HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())]) +//! Some(client), //! ).unwrap(); //! let mut text = "".to_string(); //! reader.read_to_string(&mut text).unwrap(); diff --git a/src/oneio/mod.rs b/src/oneio/mod.rs index a69e94b..94c5d53 100644 --- a/src/oneio/mod.rs +++ b/src/oneio/mod.rs @@ -18,7 +18,7 @@ use std::fs::File; use std::io::{BufWriter, Read, Write}; use std::path::Path; -fn get_writer_raw(path: &str) -> Result, OneIoError> { +pub fn get_writer_raw(path: &str) -> Result, OneIoError> { let path = Path::new(path); if let Some(prefix) = path.parent() { std::fs::create_dir_all(prefix)?; @@ -27,7 +27,7 @@ fn get_writer_raw(path: &str) -> Result, OneIoError> { Ok(output_file) } -fn get_reader_raw(path: &str) -> Result, OneIoError> { +pub fn get_reader_raw(path: &str) -> Result, OneIoError> { #[cfg(feature = "remote")] let raw_reader: Box = remote::get_reader_raw_remote(path)?; #[cfg(not(feature = "remote"))] diff --git a/src/oneio/remote.rs b/src/oneio/remote.rs index 0d7d075..ee684d1 100644 --- a/src/oneio/remote.rs +++ b/src/oneio/remote.rs @@ -1,7 +1,7 @@ use crate::oneio::compressions::OneIOCompression; use crate::oneio::{compressions, get_writer_raw}; use crate::OneIoError; -use std::collections::HashMap; +use reqwest::blocking::Client; use std::io::Read; fn get_protocol(path: &str) -> Option { @@ -12,7 +12,7 @@ fn get_protocol(path: &str) -> Option { Some(parts[0].to_string()) } -fn get_remote_ftp_raw(path: &str) -> Result, OneIoError> { +fn get_ftp_reader_raw(path: &str) -> Result, OneIoError> { if !path.starts_with("ftp://") { return Err(OneIoError::NotSupported(path.to_string())); } @@ -31,51 +31,72 @@ fn get_remote_ftp_raw(path: &str) -> Result, OneIoError> { Ok(reader) } -fn get_remote_http_raw( +fn get_http_reader_raw( path: &str, - header: HashMap, + opt_client: Option, ) -> Result { - let mut headers: reqwest::header::HeaderMap = (&header).try_into().expect("invalid headers"); - headers.insert( - reqwest::header::USER_AGENT, - reqwest::header::HeaderValue::from_static("oneio"), - ); - headers.insert( - reqwest::header::CONTENT_LENGTH, - reqwest::header::HeaderValue::from_static("0"), - ); - #[cfg(feature = "cli")] - headers.insert( - reqwest::header::CACHE_CONTROL, - reqwest::header::HeaderValue::from_static("no-cache"), - ); - let client = reqwest::blocking::Client::builder() - .default_headers(headers) - .build()?; + let client = match opt_client { + Some(c) => c, + None => { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::USER_AGENT, + reqwest::header::HeaderValue::from_static("oneio"), + ); + headers.insert( + reqwest::header::CONTENT_LENGTH, + reqwest::header::HeaderValue::from_static("0"), + ); + #[cfg(feature = "cli")] + headers.insert( + reqwest::header::CACHE_CONTROL, + reqwest::header::HeaderValue::from_static("no-cache"), + ); + Client::builder().default_headers(headers).build()? + } + }; let res = client .execute(client.get(path).build()?)? .error_for_status()?; Ok(res) } -/// Get a reader for remote content with the capability to specify headers. +/// Get a reader for remote content with the capability to specify headers, and customer reqwest options. /// -/// Example usage: +/// Example usage with custom header fields: /// ```no_run /// use std::collections::HashMap; -/// let mut reader = oneio::get_remote_reader( +/// use reqwest::header::HeaderMap; +/// let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers"); +/// let client = reqwest::blocking::Client::builder() +/// .default_headers(headers) +/// .build().unwrap(); +/// let mut reader = oneio::get_http_reader( /// "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN", -/// HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())]) +/// Some(client), +/// ).unwrap(); +/// let mut text = "".to_string(); +/// reader.read_to_string(&mut text).unwrap(); +/// println!("{}", text); +/// ``` +/// +/// Example with customer builder that allows invalid certificates (bad practice): +/// ```no_run +/// use std::collections::HashMap; +/// let client = reqwest::blocking::ClientBuilder::new().danger_accept_invalid_certs(true).build().unwrap(); +/// let mut reader = oneio::get_http_reader( +/// "https://example.com", +/// Some(client) /// ).unwrap(); /// let mut text = "".to_string(); /// reader.read_to_string(&mut text).unwrap(); /// println!("{}", text); /// ``` -pub fn get_remote_reader( +pub fn get_http_reader( path: &str, - header: HashMap, + opt_client: Option, ) -> Result, OneIoError> { - let raw_reader: Box = Box::new(get_remote_http_raw(path, header)?); + let raw_reader: Box = Box::new(get_http_reader_raw(path, opt_client)?); let file_type = *path.split('.').collect::>().last().unwrap(); match file_type { #[cfg(feature = "gz")] @@ -117,9 +138,7 @@ pub fn get_remote_reader( /// fn main() -> Result<(), OneIoError> { /// let remote_path = "https://example.com/file.txt"; /// let local_path = "path/to/save/file.txt"; -/// let header: Option> = None; -/// -/// download(remote_path, local_path, header)?; +/// download(remote_path, local_path, None)?; /// /// Ok(()) /// } @@ -127,7 +146,7 @@ pub fn get_remote_reader( pub fn download( remote_path: &str, local_path: &str, - header: Option>, + opt_client: Option, ) -> Result<(), OneIoError> { match get_protocol(remote_path) { None => { @@ -136,12 +155,12 @@ pub fn download( Some(protocol) => match protocol.as_str() { "http" | "https" => { let mut writer = get_writer_raw(local_path)?; - let mut response = get_remote_http_raw(remote_path, header.unwrap_or_default())?; + let mut response = get_http_reader_raw(remote_path, opt_client)?; response.copy_to(&mut writer)?; } "ftp" => { let mut writer = get_writer_raw(local_path)?; - let mut reader = get_remote_ftp_raw(remote_path)?; + let mut reader = get_ftp_reader_raw(remote_path)?; std::io::copy(&mut reader, &mut writer)?; } #[cfg(feature = "s3")] @@ -179,7 +198,7 @@ pub fn download( /// let local_path = "/path/to/save/file.txt"; /// let retry = 3; /// -/// match download_with_retry(remote_path, local_path, None, retry) { +/// match download_with_retry(remote_path, local_path, retry, None) { /// Ok(_) => println!("File downloaded successfully"), /// Err(e) => eprintln!("Error downloading file: {:?}", e), /// } @@ -187,12 +206,12 @@ pub fn download( pub fn download_with_retry( remote_path: &str, local_path: &str, - header: Option>, retry: usize, + opt_client: Option, ) -> Result<(), OneIoError> { let mut retry = retry; loop { - match download(remote_path, local_path, header.clone()) { + match download(remote_path, local_path, opt_client.clone()) { Ok(_) => { return Ok(()); } @@ -212,11 +231,11 @@ pub(crate) fn get_reader_raw_remote(path: &str) -> Result, let raw_reader: Box = match get_protocol(path) { Some(protocol) => match protocol.as_str() { "http" | "https" => { - let response = get_remote_http_raw(path, HashMap::new())?; + let response = get_http_reader_raw(path, None)?; Box::new(response) } "ftp" => { - let response = get_remote_ftp_raw(path)?; + let response = get_ftp_reader_raw(path)?; Box::new(response) } #[cfg(feature = "s3")]