diff --git a/src/api/mod.rs b/src/api/mod.rs
index 121f545..c8ea97f 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,4 +1,6 @@
-use indicatif::{ProgressBar, ProgressStyle};
+use std::{collections::VecDeque, time::Duration};
+
+use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
 use serde::Deserialize;
 
 /// The asynchronous version of the API
@@ -35,9 +37,9 @@ impl Progress for ProgressBar {
         self.set_length(size as u64);
         self.set_style(
             ProgressStyle::with_template(
-                "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
-            )
-            .unwrap(), // .progress_chars("━ "),
+                "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
+            )
+            .unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default()),
         );
         let maxlength = 30;
         let message = if filename.len() > maxlength {
@@ -73,3 +75,48 @@ pub struct RepoInfo {
     /// The commit sha of the repo.
     pub sha: String,
 }
+
+#[derive(Clone, Default)]
+struct MovingAvgRate {
+    samples: VecDeque<(std::time::Instant, u64)>,
+}
+
+impl ProgressTracker for MovingAvgRate {
+    fn clone_box(&self) -> Box<dyn ProgressTracker> {
+        Box::new(self.clone())
+    }
+
+    fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
+        // Sample at most every 20ms
+        if self
+            .samples
+            .back()
+            .map_or(true, |(prev, _)| (now - *prev) > Duration::from_millis(20))
+        {
+            self.samples.push_back((now, state.pos()));
+        }
+
+        // Only keep samples from the last second, so the rate is a moving average
+        while let Some(first) = self.samples.front() {
+            if now - first.0 > Duration::from_secs(1) {
+                self.samples.pop_front();
+            } else {
+                break;
+            }
+        }
+    }
+
+    fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
+        self.samples = Default::default();
+    }
+
+    fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
+        match (self.samples.front(), self.samples.back()) {
+            (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
+                let elapsed_ms = (*t1 - *t0).as_millis();
+                let rate = ((p1 - p0) as f64 * 1000f64 / elapsed_ms as f64) as u64;
+                write!(w, "{}/s", HumanBytes(rate)).unwrap()
+            }
+            _ => write!(w, "-").unwrap(),
+        }
+    }
+}
diff --git a/src/api/sync.rs b/src/api/sync.rs
index 3748c19..a4032c7 100644
--- a/src/api/sync.rs
+++ b/src/api/sync.rs
@@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
 type HeaderMap = HashMap<&'static str, String>;
 type HeaderName = &'static str;
 
+/// Extension of the part file used to resume an interrupted download
+const EXTENSION: &str = ".part";
+
 struct Wrapper<'a, P: Progress, R: Read> {
     progress: &'a mut P,
     inner: R,
@@ -104,6 +107,10 @@ pub enum ApiError {
     #[error("Native tls: {0}")]
     #[cfg(feature = "native-tls")]
     Native(#[from] native_tls::Error),
+
+    /// The part file is corrupted
+    #[error("Invalid part file - corrupted file")]
+    InvalidResume,
 }
 
 /// Helper to create [`Api`] with all the options.
@@ -436,15 +443,26 @@ impl Api {
         url: &str,
         size: usize,
         mut progress: P,
+        tmp_path: PathBuf,
         filename: &str,
     ) -> Result<PathBuf, ApiError> {
         progress.init(size, filename);
-        let filepath = self.cache.temp_path();
+        let filepath = tmp_path;
 
         // Create the file and set everything properly
-        let mut file = std::fs::File::create(&filepath)?;
-        let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
+        let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
+            Ok(f) => f,
+            Err(_) => std::fs::File::create(&filepath)?,
+        };
+
+        // In case of resume, pick up from the end of the existing part file.
+        let start = file.metadata()?.len();
+        if start > size as u64 {
+            return Err(ApiError::InvalidResume);
+        }
+
+        let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
         if self.max_retries > 0 {
             let mut i = 0;
             while let Err(dlerr) = res {
@@ -631,9 +649,11 @@ impl ApiRepo {
             .blob_path(&metadata.etag);
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
-        let tmp_filename = self
-            .api
-            .download_tempfile(&url, metadata.size, progress, filename)?;
+        let mut tmp_path = blob_path.clone();
+        tmp_path.set_extension(EXTENSION);
+        let tmp_filename =
+            self.api
+                .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;
 
         std::fs::rename(tmp_filename, &blob_path)?;
 
         let mut pointer_path = self
@@ -704,6 +724,7 @@ mod tests {
     use rand::{distributions::Alphanumeric, Rng};
     use serde_json::{json, Value};
    use sha2::{Digest, Sha256};
+    use std::io::{Seek, SeekFrom, Write};
 
     struct TempDir {
         path: PathBuf,
@@ -756,6 +777,85 @@ mod tests {
         assert_eq!(cache_path, downloaded_path);
     }
 
+    #[test]
+    fn resume() {
+        let tmp = TempDir::new();
+        let api = ApiBuilder::new()
+            .with_progress(false)
+            .with_cache_dir(tmp.path.clone())
+            .build()
+            .unwrap();
+
+        let model_id = "julien-c/dummy-unknown".to_string();
+        let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        assert!(downloaded_path.exists());
+        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // Turn the blob into a truncated part file, and check that the next
+        // download resumes from it.
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        let truncate: f32 = rand::random();
+        let new_size = (size as f32 * truncate) as u64;
+        file.set_len(new_size).unwrap();
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // Here we prove the previous download was correctly resumed, by
+        // purposely corrupting the part file.
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        // Not random, so the corrupted sha is deterministic
+        let truncate: f32 = 0.5;
+        let new_size = (size as f32 * truncate) as u64;
+        // Truncating
+        file.set_len(new_size).unwrap();
+        // Corrupting by changing a single byte.
+        file.seek(SeekFrom::Start(new_size - 1)).unwrap();
+        file.write_all(&[0]).unwrap();
+
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        println!("Sha {val:#x}");
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            // Corrupted sha
+            hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
+        );
+    }
+
     #[test]
     fn simple_with_retries() {
         let tmp = TempDir::new();
diff --git a/src/api/tokio.rs b/src/api/tokio.rs
index 82f79c2..fbbe814 100644
--- a/src/api/tokio.rs
+++ b/src/api/tokio.rs
@@ -1,6 +1,7 @@
 use super::Progress as SyncProgress;
 use super::{RepoInfo, HF_ENDPOINT};
 use crate::{Cache, Repo, RepoType};
+use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 use indicatif::ProgressBar;
 use rand::Rng;
@@ -12,18 +13,24 @@ use reqwest::{
     redirect::Policy,
     Client, Error as ReqwestError, RequestBuilder,
 };
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
 use std::num::ParseIntError;
 use std::path::{Component, Path, PathBuf};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::io::AsyncReadExt;
 use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
 use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
+use tokio::task::JoinError;
 
 /// Current version (used in user-agent)
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 /// Current name (used in user-agent)
 const NAME: &str = env!("CARGO_PKG_NAME");
 
+/// Extension of the part file used to resume an interrupted async download
+const EXTENSION: &str = ".sync.part";
+
 /// This trait is used by users of the lib
 /// to implement custom behavior during file downloads
 pub trait Progress {
@@ -101,6 +108,9 @@ pub enum ApiError {
     // /// Semaphore cannot be acquired
     // #[error("Invalid Response: {0:?}")]
     // InvalidResponse(Response),
+    /// A download task failed to join
+    #[error("Join: {0}")]
+    Join(#[from] JoinError),
 }
 
 /// Helper to create [`Api`] with all the options.
@@ -182,8 +192,8 @@ impl ApiBuilder {
             cache,
             token,
             max_files: 1,
-            // chunk_size: 10_000_000,
-            chunk_size: None,
+            // A chunk size is required for the download to be resumable.
+            chunk_size: Some(10_000_000),
             parallel_failures: 0,
             max_retries: 0,
             progress,
@@ -507,32 +517,61 @@ impl ApiRepo {
         &self,
         url: &str,
         length: usize,
+        filename: PathBuf,
         mut progressbar: P,
     ) -> Result<PathBuf, ApiError> {
-        let mut handles = vec![];
         let semaphore = Arc::new(Semaphore::new(self.api.max_files));
         let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
-        let filename = self.api.cache.temp_path();
 
         // Create the file and set everything properly
-        tokio::fs::File::create(&filename)
-            .await?
-            .set_len(length as u64)
-            .await?;
+        // The part file stores the payload (`length` bytes) followed by a
+        // little-endian u64 trailer recording how many bytes are committed.
+        const N_BYTES: usize = size_of::<u64>();
+        let start = match tokio::fs::OpenOptions::new()
+            .read(true)
+            .open(&filename)
+            .await
+        {
+            Ok(mut f) => {
+                let len = f.metadata().await?.len();
+                if len == (length + N_BYTES) as u64 {
+                    f.seek(SeekFrom::Start(length as u64)).await?;
+                    let mut buf = [0u8; N_BYTES];
+                    let n = f.read(buf.as_mut_slice()).await?;
+                    if n == N_BYTES {
+                        let committed = u64::from_le_bytes(buf);
+                        committed as usize
+                    } else {
+                        0
+                    }
+                } else {
+                    0
+                }
+            }
+            Err(_err) => {
+                tokio::fs::File::create(&filename)
+                    .await?
+                    .set_len((length + N_BYTES) as u64)
+                    .await?;
+                0
+            }
+        };
+        progressbar.update(start).await;
 
         let chunk_size = self.api.chunk_size.unwrap_or(length);
-        for start in (0..length).step_by(chunk_size) {
+        let n_chunks = length / chunk_size;
+        let mut handles = Vec::with_capacity(n_chunks);
+        // Skip the chunks that were already committed by a previous run.
+        for start in (start..length).step_by(chunk_size) {
             let url = url.to_string();
             let filename = filename.clone();
             let client = self.api.client.clone();
 
             let stop = std::cmp::min(start + chunk_size - 1, length);
-            let permit = semaphore.clone().acquire_owned().await?;
+            let semaphore = semaphore.clone();
             let parallel_failures = self.api.parallel_failures;
             let max_retries = self.api.max_retries;
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
+                let permit = semaphore.acquire_owned().await?;
                 let mut chunk =
                     Self::download_chunk(&client, &url, &filename, start, stop, progress.clone())
                         .await;
@@ -563,18 +602,43 @@ impl ApiRepo {
                     }
                 }
                 drop(permit);
-                // if let Some(p) = progress {
-                //     progress.update(stop - start).await;
-                // }
                 chunk
             }));
         }
 
-        // Output the chained result
-        let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
-            futures::future::join_all(handles).await;
-        let results: Result<(), ApiError> = results.into_iter().flatten().collect();
-        results?;
+        let mut futures: FuturesUnordered<_> = handles.into_iter().collect();
+        let mut temporaries = BinaryHeap::new();
+        let mut committed: u64 = start as u64;
+        while let Some(chunk) = futures.next().await {
+            let chunk = chunk?;
+            let (start, stop) = chunk?;
+            temporaries.push(Reverse((start, stop)));
+
+            // Fold every finished chunk that is contiguous with the committed
+            // prefix into it, then persist the new offset in the trailer.
+            let mut modified = false;
+            while let Some(Reverse((min, max))) = temporaries.pop() {
+                if min as u64 == committed {
+                    committed = max as u64 + 1;
+                    modified = true;
+                } else {
+                    temporaries.push(Reverse((min, max)));
+                    break;
+                }
+            }
+            if modified {
+                let mut file = tokio::fs::OpenOptions::new()
+                    .write(true)
+                    .open(&filename)
+                    .await?;
+                file.seek(SeekFrom::Start(length as u64)).await?;
+                file.write_all(&committed.to_le_bytes()).await?;
+            }
+        }
+        // The download is complete: truncate the trailer away.
+        tokio::fs::OpenOptions::new()
+            .write(true)
+            .open(&filename)
+            .await?
+            .set_len(length as u64)
+            .await?;
         progressbar.finish().await;
         Ok(filename)
     }
@@ -586,17 +650,12 @@ impl ApiRepo {
         start: usize,
         stop: usize,
         mut progress: P,
-    ) -> Result<(), ApiError>
+    ) -> Result<(usize, usize), ApiError>
     where
         P: Progress,
    {
         // Process each socket concurrently.
         let range = format!("bytes={start}-{stop}");
-        let mut file = tokio::fs::OpenOptions::new()
-            .write(true)
-            .open(filename)
-            .await?;
-        file.seek(SeekFrom::Start(start as u64)).await?;
         let response = client
             .get(url)
             .header(RANGE, range)
             .send()
             .await?
             .error_for_status()?;
         let mut byte_stream = response.bytes_stream();
+        // Buffer the chunk in memory and write it to disk in a single pass
+        // once it has fully arrived.
+        let mut buf: Vec<u8> = Vec::with_capacity(stop - start);
         while let Some(next) = byte_stream.next().await {
             let next = next?;
-            file.write_all(&next).await?;
+            buf.extend(&next);
             progress.update(next.len()).await;
         }
-        Ok(())
+        let mut file = tokio::fs::OpenOptions::new()
+            .write(true)
+            .open(filename)
+            .await?;
+        file.seek(SeekFrom::Start(start as u64)).await?;
+        file.write_all(&buf).await?;
+        Ok((start, stop))
     }
 
     /// This will attempt to fetch the file locally first, then [`Api.download`]
@@ -694,9 +760,12 @@ impl ApiRepo {
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
         progress.init(metadata.size, filename).await;
+        let mut tmp_path = blob_path.clone();
+        tmp_path.set_extension(EXTENSION);
         let tmp_filename = self
-            .download_tempfile(&url, metadata.size, progress)
-            .await?;
+            .download_tempfile(&url, metadata.size, tmp_path, progress)
+            .await?;
 
         tokio::fs::rename(&tmp_filename, &blob_path).await?;
 
@@ -749,6 +818,7 @@ mod tests {
     use rand::distributions::Alphanumeric;
     use serde_json::{json, Value};
     use sha2::{Digest, Sha256};
+    use std::io::{Seek, Write};
 
     struct TempDir {
         path: PathBuf,
@@ -797,6 +867,138 @@ mod tests {
         assert_eq!(cache_path, downloaded_path);
     }
 
+    #[tokio::test]
+    async fn resume() {
+        let tmp = TempDir::new();
+        let api = ApiBuilder::new()
+            .with_progress(false)
+            .with_cache_dir(tmp.path.clone())
+            .build()
+            .unwrap();
+        let model_id = "julien-c/dummy-unknown".to_string();
+        let downloaded_path = api
+            .model(model_id.clone())
+            .download("config.json")
+            .await
+            .unwrap();
+        assert!(downloaded_path.exists());
+        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // This effectively replaces the part file with a trashed version of
+        // it (wrong size, no trailer); a full redownload will ensue.
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        let truncate: f32 = rand::random();
+        let new_size = (size as f32 * truncate) as u64;
+        file.set_len(new_size).unwrap();
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".sync.part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..]
!= hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + let new_downloaded_path = api + .model(model_id.clone()) + .download("config.json") + .await + .unwrap(); + let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); + assert_eq!(downloaded_path, new_downloaded_path); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + + // Now this is a valid partial download file + let blob = std::fs::canonicalize(&downloaded_path).unwrap(); + let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap(); + let size = file.metadata().unwrap().len(); + let truncate: f32 = rand::random(); + let new_size = (size as f32 * truncate) as u64; + // Truncating + file.set_len(new_size).unwrap(); + let total_size = size + size_of::() as u64; + file.set_len(total_size).unwrap(); + file.seek(SeekFrom::Start(size)).unwrap(); + file.write_all(&new_size.to_le_bytes()).unwrap(); + + let mut blob_part = blob.clone(); + blob_part.set_extension(".sync.part"); + std::fs::rename(blob, &blob_part).unwrap(); + std::fs::remove_file(&downloaded_path).unwrap(); + let content = std::fs::read(&*blob_part).unwrap(); + assert_eq!(content.len() as u64, total_size); + let val = Sha256::digest(content); + // We modified the sha. + assert!( + val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + let new_downloaded_path = api + .model(model_id.clone()) + .download("config.json") + .await + .unwrap(); + let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); + assert_eq!(downloaded_path, new_downloaded_path); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + + // Here we prove the previous part was correctly resuming by purposefully corrupting the + // file. + let blob = std::fs::canonicalize(&downloaded_path).unwrap(); + let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap(); + let size = file.metadata().unwrap().len(); + // Not random for consistent sha corruption + let truncate: f32 = 0.5; + let new_size = (size as f32 * truncate) as u64; + // Truncating + file.set_len(new_size).unwrap(); + let total_size = size + size_of::() as u64; + file.set_len(total_size).unwrap(); + file.seek(SeekFrom::Start(size)).unwrap(); + file.write_all(&new_size.to_le_bytes()).unwrap(); + + // Corrupting by changing a single byte. + file.seek(SeekFrom::Start(new_size - 1)).unwrap(); + file.write_all(&[0]).unwrap(); + + let mut blob_part = blob.clone(); + blob_part.set_extension(".sync.part"); + std::fs::rename(blob, &blob_part).unwrap(); + std::fs::remove_file(&downloaded_path).unwrap(); + let content = std::fs::read(&*blob_part).unwrap(); + assert_eq!(content.len() as u64, total_size); + let val = Sha256::digest(content); + // We modified the sha. + assert!( + val[..] 
!= hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + let new_downloaded_path = api + .model(model_id.clone()) + .download("config.json") + .await + .unwrap(); + let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); + assert_eq!(downloaded_path, new_downloaded_path); + assert_eq!( + val[..], + // Corrupted sha + hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5") + ); + } + #[tokio::test] async fn revision() { let tmp = TempDir::new(); diff --git a/src/lib.rs b/src/lib.rs index a832c28..2f1c93c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,8 +4,6 @@ not(feature = "ureq"), doc = "Documentation is meant to be compiled with default features (at least ureq)" )] -#[cfg(any(feature = "tokio", feature = "ureq"))] -use rand::{distributions::Alphanumeric, Rng}; use std::io::Write; use std::path::PathBuf; @@ -124,21 +122,6 @@ impl Cache { pub fn space(&self, model_id: String) -> CacheRepo { self.repo(Repo::new(model_id, RepoType::Space)) } - - #[cfg(any(feature = "tokio", feature = "ureq"))] - pub(crate) fn temp_path(&self) -> PathBuf { - let mut path = self.path().clone(); - path.push("tmp"); - std::fs::create_dir_all(&path).ok(); - - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - path.push(s); - path.to_path_buf() - } } /// Shorthand for accessing things within a particular repo