Skip to content

Commit

Permalink
Add matching for valid model weight names (#571)
Browse files Browse the repository at this point in the history
* Add matching for valid model weight names

* Fix clippy
  • Loading branch information
EricLBuehler authored Jul 13, 2024
1 parent 46bd23d commit d594708
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion mistralrs-core/src/pipeline/paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use hf_hub::{
api::sync::{ApiBuilder, ApiRepo},
Repo, RepoType,
};
use regex_automata::meta::Regex;
use serde_json::Value;
use tracing::{info, warn};

Expand All @@ -17,6 +18,10 @@ use crate::{
utils::tokens::get_token, xlora_models::XLoraConfig, ModelPaths, Ordering, TokenSource,
};

// Patterns for the weight files of the base model. Only the canonical Hugging Face
// names are accepted, which avoids picking up alternate checkpoints stored in the
// same repository, such as `consolidated.safetensors`.
//
// Each pattern accepts both layouts:
//   * single file: `model.safetensors`, `pytorch_model.bin`
//   * sharded:     `model-00001-of-00002.safetensors`, `pytorch_model-00001-of-00002.bin`
//
// `(?:^|/)` pins the match to the start of the file name (optionally after a directory
// prefix) so that prefixed names such as `adapter_model.safetensors` are rejected, and
// the trailing `$` pins the extension so that sidecar files such as
// `model.safetensors.index.json` are rejected.
const SAFETENSOR_MATCH: &str = r"(?:^|/)model(?:-\d{5}-of-\d{5})?\.safetensors$";
const PICKLE_MATCH: &str = r"(?:^|/)pytorch_model(?:-\d{5}-of-\d{5})?\.(?:bin|pth|pt)$";

pub(crate) struct XLoraPaths {
pub adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
pub adapter_safetensors: Option<Vec<(String, PathBuf)>>,
Expand Down Expand Up @@ -273,8 +278,13 @@ pub fn get_model_paths(
Ok(files)
}
None => {
// We only match these patterns for model names
let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
let pickle_match = Regex::new(PICKLE_MATCH)?;

let mut filenames = vec![];
let listing = api_dir_list!(api, model_id);
let listing = api_dir_list!(api, model_id)
.filter(|x| safetensor_match.is_match(x) || pickle_match.is_match(x));
let safetensors = listing
.clone()
.filter(|x| x.ends_with(".safetensors"))
Expand All @@ -292,6 +302,13 @@ pub fn get_model_paths(
} else {
anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
};
info!(
"Found model weight filenames {:?}",
files
.iter()
.map(|x| x.split('/').last().unwrap())
.collect::<Vec<_>>()
);
for rfilename in files {
filenames.push(api_get_file!(api, &rfilename, model_id));
}
Expand Down Expand Up @@ -420,3 +437,61 @@ pub(crate) fn get_chat_template(
}
}
}

// Unit tests for the model weight file name patterns. Gated on `cfg(test)` so that the
// module and its imports are only compiled for test builds.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, SAFETENSOR_MATCH};

    /// Compiles `pattern` and asserts that it matches every name in `accepted` and
    /// none of the names in `rejected`.
    ///
    /// Returns an error if `pattern` fails to compile; panics (failing the test) with
    /// the offending file name on the first wrong match result.
    fn check_pattern(pattern: &str, accepted: &[&str], rejected: &[&str]) -> anyhow::Result<()> {
        let matcher = Regex::new(pattern)?;
        for name in accepted {
            assert!(
                matcher.is_match(*name),
                "pattern `{}` should match `{}`",
                pattern,
                name
            );
        }
        for name in rejected {
            assert!(
                !matcher.is_match(*name),
                "pattern `{}` should not match `{}`",
                pattern,
                name
            );
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        check_pattern(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &[
                // Six digits in the shard index.
                "model-000001-of-00001.safetensors",
                // Non-digit in the shard index.
                "model-0000a-of-00002.safetensors",
                // Too few digits in the shard index.
                "model-000-of-00003.safetensors",
                // Alternate checkpoint that must not be picked up.
                "consolidated.safetensors",
            ],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        check_pattern(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                // Six digits in the shard index.
                "pytorch_model-000001-of-00001.bin",
                // Non-digit in the shard index.
                "pytorch_model-0000a-of-00002.bin",
                // Too few digits in the shard index.
                "pytorch_model-000-of-00003.bin",
                // Alternate checkpoint that must not be picked up.
                "pytorch_consolidated.bin",
            ],
        )
    }
}

0 comments on commit d594708

Please sign in to comment.