Skip to content

Commit

Permalink
Merge pull request #53 from bgpkit/feature/custom-client
Browse files Browse the repository at this point in the history
[feature] allow specifying custom reqwest client
  • Loading branch information
digizeph authored Jun 4, 2024
2 parents 1dbffb0 + ccd9fcb commit 88714ee
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 45 deletions.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,20 @@ fn main() {

```rust
use std::collections::HashMap;
use std::collections::HashMap;
use reqwest::header::HeaderMap;

fn main() {
let mut reader = oneio::get_remote_reader(
let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())]))
.try_into().expect("invalid headers");

let client = reqwest::blocking::Client::builder()
.default_headers(headers)
.danger_accept_invalid_certs(true)
.build().unwrap();
let mut reader = oneio::get_http_reader(
"https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
Some(client),
).unwrap();
let mut text = "".to_string();
reader.read_to_string(&mut text).unwrap();
Expand Down
9 changes: 7 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@
//! Read remote content with custom headers
//! ```no_run
//! use std::collections::HashMap;
//! let mut reader = oneio::get_remote_reader(
//! use reqwest::header::HeaderMap;
//! let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers");
//! let client = reqwest::blocking::Client::builder()
//! .default_headers(headers)
//! .build().unwrap();
//! let mut reader = oneio::get_http_reader(
//! "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
//! HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
//! Some(client),
//! ).unwrap();
//! let mut text = "".to_string();
//! reader.read_to_string(&mut text).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/oneio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::fs::File;
use std::io::{BufWriter, Read, Write};
use std::path::Path;

fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
pub fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
let path = Path::new(path);
if let Some(prefix) = path.parent() {
std::fs::create_dir_all(prefix)?;
Expand All @@ -27,7 +27,7 @@ fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
Ok(output_file)
}

fn get_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
pub fn get_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
#[cfg(feature = "remote")]
let raw_reader: Box<dyn Read + Send> = remote::get_reader_raw_remote(path)?;
#[cfg(not(feature = "remote"))]
Expand Down
97 changes: 58 additions & 39 deletions src/oneio/remote.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::oneio::compressions::OneIOCompression;
use crate::oneio::{compressions, get_writer_raw};
use crate::OneIoError;
use std::collections::HashMap;
use reqwest::blocking::Client;
use std::io::Read;

fn get_protocol(path: &str) -> Option<String> {
Expand All @@ -12,7 +12,7 @@ fn get_protocol(path: &str) -> Option<String> {
Some(parts[0].to_string())
}

fn get_remote_ftp_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
fn get_ftp_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
if !path.starts_with("ftp://") {
return Err(OneIoError::NotSupported(path.to_string()));
}
Expand All @@ -31,51 +31,72 @@ fn get_remote_ftp_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
Ok(reader)
}

fn get_remote_http_raw(
fn get_http_reader_raw(
path: &str,
header: HashMap<String, String>,
opt_client: Option<Client>,
) -> Result<reqwest::blocking::Response, OneIoError> {
let mut headers: reqwest::header::HeaderMap = (&header).try_into().expect("invalid headers");
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("oneio"),
);
headers.insert(
reqwest::header::CONTENT_LENGTH,
reqwest::header::HeaderValue::from_static("0"),
);
#[cfg(feature = "cli")]
headers.insert(
reqwest::header::CACHE_CONTROL,
reqwest::header::HeaderValue::from_static("no-cache"),
);
let client = reqwest::blocking::Client::builder()
.default_headers(headers)
.build()?;
let client = match opt_client {
Some(c) => c,
None => {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("oneio"),
);
headers.insert(
reqwest::header::CONTENT_LENGTH,
reqwest::header::HeaderValue::from_static("0"),
);
#[cfg(feature = "cli")]
headers.insert(
reqwest::header::CACHE_CONTROL,
reqwest::header::HeaderValue::from_static("no-cache"),
);
Client::builder().default_headers(headers).build()?
}
};
let res = client
.execute(client.get(path).build()?)?
.error_for_status()?;
Ok(res)
}

/// Get a reader for remote content with the capability to specify headers.
/// Get a reader for remote content with the capability to specify headers, and customer reqwest options.
///
/// Example usage:
/// Example usage with custom header fields:
/// ```no_run
/// use std::collections::HashMap;
/// let mut reader = oneio::get_remote_reader(
/// use reqwest::header::HeaderMap;
/// let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers");
/// let client = reqwest::blocking::Client::builder()
/// .default_headers(headers)
/// .build().unwrap();
/// let mut reader = oneio::get_http_reader(
/// "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
/// HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
/// Some(client),
/// ).unwrap();
/// let mut text = "".to_string();
/// reader.read_to_string(&mut text).unwrap();
/// println!("{}", text);
/// ```
///
/// Example with customer builder that allows invalid certificates (bad practice):
/// ```no_run
/// use std::collections::HashMap;
/// let client = reqwest::blocking::ClientBuilder::new().danger_accept_invalid_certs(true).build().unwrap();
/// let mut reader = oneio::get_http_reader(
/// "https://example.com",
/// Some(client)
/// ).unwrap();
/// let mut text = "".to_string();
/// reader.read_to_string(&mut text).unwrap();
/// println!("{}", text);
/// ```
pub fn get_remote_reader(
pub fn get_http_reader(
path: &str,
header: HashMap<String, String>,
opt_client: Option<Client>,
) -> Result<Box<dyn Read + Send>, OneIoError> {
let raw_reader: Box<dyn Read + Send> = Box::new(get_remote_http_raw(path, header)?);
let raw_reader: Box<dyn Read + Send> = Box::new(get_http_reader_raw(path, opt_client)?);
let file_type = *path.split('.').collect::<Vec<&str>>().last().unwrap();
match file_type {
#[cfg(feature = "gz")]
Expand Down Expand Up @@ -117,17 +138,15 @@ pub fn get_remote_reader(
/// fn main() -> Result<(), OneIoError> {
/// let remote_path = "https://example.com/file.txt";
/// let local_path = "path/to/save/file.txt";
/// let header: Option<HashMap<String, String>> = None;
///
/// download(remote_path, local_path, header)?;
/// download(remote_path, local_path, None)?;
///
/// Ok(())
/// }
/// ```
pub fn download(
remote_path: &str,
local_path: &str,
header: Option<HashMap<String, String>>,
opt_client: Option<Client>,
) -> Result<(), OneIoError> {
match get_protocol(remote_path) {
None => {
Expand All @@ -136,12 +155,12 @@ pub fn download(
Some(protocol) => match protocol.as_str() {
"http" | "https" => {
let mut writer = get_writer_raw(local_path)?;
let mut response = get_remote_http_raw(remote_path, header.unwrap_or_default())?;
let mut response = get_http_reader_raw(remote_path, opt_client)?;
response.copy_to(&mut writer)?;
}
"ftp" => {
let mut writer = get_writer_raw(local_path)?;
let mut reader = get_remote_ftp_raw(remote_path)?;
let mut reader = get_ftp_reader_raw(remote_path)?;
std::io::copy(&mut reader, &mut writer)?;
}
#[cfg(feature = "s3")]
Expand Down Expand Up @@ -179,20 +198,20 @@ pub fn download(
/// let local_path = "/path/to/save/file.txt";
/// let retry = 3;
///
/// match download_with_retry(remote_path, local_path, None, retry) {
/// match download_with_retry(remote_path, local_path, retry, None) {
/// Ok(_) => println!("File downloaded successfully"),
/// Err(e) => eprintln!("Error downloading file: {:?}", e),
/// }
/// ```
pub fn download_with_retry(
remote_path: &str,
local_path: &str,
header: Option<HashMap<String, String>>,
retry: usize,
opt_client: Option<Client>,
) -> Result<(), OneIoError> {
let mut retry = retry;
loop {
match download(remote_path, local_path, header.clone()) {
match download(remote_path, local_path, opt_client.clone()) {
Ok(_) => {
return Ok(());
}
Expand All @@ -212,11 +231,11 @@ pub(crate) fn get_reader_raw_remote(path: &str) -> Result<Box<dyn Read + Send>,
let raw_reader: Box<dyn Read + Send> = match get_protocol(path) {
Some(protocol) => match protocol.as_str() {
"http" | "https" => {
let response = get_remote_http_raw(path, HashMap::new())?;
let response = get_http_reader_raw(path, None)?;
Box::new(response)
}
"ftp" => {
let response = get_remote_ftp_raw(path)?;
let response = get_ftp_reader_raw(path)?;
Box::new(response)
}
#[cfg(feature = "s3")]
Expand Down

0 comments on commit 88714ee

Please sign in to comment.