Skip to content

Commit

Permalink
Adding options for environment variables. (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Dec 30, 2024
1 parent 9f4d2d6 commit 57c58af
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 20 deletions.
2 changes: 2 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
default = pkgs.mkShell {
buildInputs = with pkgs; [
rustup
pkg-config
openssl
];
};

Expand Down
2 changes: 2 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub mod tokio;
#[cfg(feature = "ureq")]
pub mod sync;

/// Name of the environment variable whose value, when set, is used as the
/// endpoint by the `from_env` constructors of the API builders.
const HF_ENDPOINT: &str = "HF_ENDPOINT";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
Expand Down
19 changes: 18 additions & 1 deletion src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::RepoInfo;
use super::{RepoInfo, HF_ENDPOINT};
use crate::api::sync::ApiError::InvalidHeader;
use crate::api::Progress;
use crate::{Cache, Repo, RepoType};
Expand Down Expand Up @@ -133,6 +133,23 @@ impl ApiBuilder {
Self::from_cache(cache)
}

/// Builds an [`ApiBuilder`] configured from environment variables.
///
/// - `HF_HOME` selects the location of the cache folder
///   (see [`Cache::from_env`]).
/// - `HF_ENDPOINT` overrides the URL of the huggingface location
///   files are downloaded from. If it cannot be read, the endpoint
///   already set by [`ApiBuilder::from_cache`] is kept.
///
/// ```
/// use hf_hub::api::sync::ApiBuilder;
/// let api = ApiBuilder::from_env().build().unwrap();
/// ```
pub fn from_env() -> Self {
    let builder = Self::from_cache(Cache::from_env());
    // Only override the endpoint when the variable could actually be read.
    match std::env::var(HF_ENDPOINT) {
        Ok(endpoint) => builder.with_endpoint(endpoint),
        Err(_) => builder,
    }
}

/// From a given cache
/// ```
/// use hf_hub::{api::sync::ApiBuilder, Cache};
Expand Down
27 changes: 21 additions & 6 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Progress as SyncProgress;
use super::RepoInfo;
use super::{RepoInfo, HF_ENDPOINT};
use crate::{Cache, Repo, RepoType};
use futures::StreamExt;
use indicatif::ProgressBar;
Expand Down Expand Up @@ -133,6 +133,23 @@ impl ApiBuilder {
Self::from_cache(cache)
}

/// Builds an [`ApiBuilder`] configured from environment variables.
///
/// - `HF_HOME` selects the location of the cache folder
///   (see [`Cache::from_env`]).
/// - `HF_ENDPOINT` overrides the URL of the huggingface location
///   files are downloaded from. If it cannot be read, the endpoint
///   already set by [`ApiBuilder::from_cache`] is kept.
///
/// ```
/// use hf_hub::api::tokio::ApiBuilder;
/// let api = ApiBuilder::from_env().build().unwrap();
/// ```
pub fn from_env() -> Self {
    let builder = Self::from_cache(Cache::from_env());
    // Only override the endpoint when the variable could actually be read.
    match std::env::var(HF_ENDPOINT) {
        Ok(endpoint) => builder.with_endpoint(endpoint),
        Err(_) => builder,
    }
}

/// High CPU download
///
/// This may cause issues on regular desktops as it will saturate
Expand All @@ -141,12 +158,10 @@ impl ApiBuilder {
/// saturate the bandwidth (>500MB/s) better.
/// ```
/// use hf_hub::api::tokio::ApiBuilder;
/// let api = ApiBuilder::high().build().unwrap();
/// let api = ApiBuilder::new().high().build().unwrap();
/// ```
pub fn high() -> Self {
let cache = Cache::default();
Self::from_cache(cache)
.with_max_files(num_cpus::get())
pub fn high(self) -> Self {
self.with_max_files(num_cpus::get())
.with_chunk_size(Some(10_000_000))
}

Expand Down
36 changes: 23 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::path::PathBuf;
#[cfg(any(feature = "tokio", feature = "ureq"))]
pub mod api;

/// Name of the environment variable read by `Cache::from_env`; when set, the
/// cache is placed in the `hub` subdirectory of the directory it names.
const HF_HOME: &str = "HF_HOME";

/// The type of repo to interact with
#[derive(Debug, Clone, Copy)]
pub enum RepoType {
Expand All @@ -37,6 +39,19 @@ impl Cache {
Self { path }
}

/// Creates a cache whose location is taken from the environment.
///
/// When the `HF_HOME` environment variable can be read, the cache lives in
/// the `hub` subdirectory of the directory it names. Otherwise (variable
/// unset or not valid Unicode) this falls back to [`Cache::default`].
pub fn from_env() -> Self {
    std::env::var(HF_HOME)
        .map(|home| Self::new(PathBuf::from(home).join("hub")))
        .unwrap_or_else(|_| Self::default())
}

/// Returns the path of the cache location
pub fn path(&self) -> &PathBuf {
&self.path
Expand Down Expand Up @@ -137,6 +152,7 @@ impl CacheRepo {
fn new(cache: Cache, repo: Repo) -> Self {
Self { cache, repo }
}

/// This will get the location of the file within the cache for the remote
/// `filename`. Will return `None` if file is not already present in cache.
pub fn get(&self, filename: &str) -> Option<PathBuf> {
Expand Down Expand Up @@ -197,15 +213,9 @@ impl CacheRepo {

impl Default for Cache {
fn default() -> Self {
let mut path = match std::env::var("HF_HOME") {
Ok(home) => home.into(),
Err(_) => {
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
cache.push(".cache");
cache.push("huggingface");
cache
}
};
let mut path = dirs::home_dir().expect("Cache directory cannot be found");
path.push(".cache");
path.push("huggingface");
path.push("hub");
Self::new(path)
}
Expand Down Expand Up @@ -338,9 +348,9 @@ mod tests {
#[test]
#[cfg(not(target_os = "windows"))]
fn token_path() {
let cache = Cache::default();
let cache = Cache::from_env();
let token_path = cache.token_path().to_str().unwrap().to_string();
if let Ok(hf_home) = std::env::var("HF_HOME") {
if let Ok(hf_home) = std::env::var(HF_HOME) {
assert_eq!(token_path, format!("{hf_home}/token"));
} else {
let n = "huggingface/token".len();
Expand All @@ -351,9 +361,9 @@ mod tests {
#[test]
#[cfg(target_os = "windows")]
fn token_path() {
let cache = Cache::default();
let cache = Cache::from_env();
let token_path = cache.token_path().to_str().unwrap().to_string();
if let Ok(hf_home) = std::env::var("HF_HOME") {
if let Ok(hf_home) = std::env::var(HF_HOME) {
assert_eq!(token_path, format!("{hf_home}\\token"));
} else {
let n = "huggingface/token".len();
Expand Down

0 comments on commit 57c58af

Please sign in to comment.