From 372d643ef90e886918541b2f9472819f06a7d53a Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 27 Oct 2024 10:24:56 +0000
Subject: [PATCH] Support `cache_position` inputs in Hugging Face models

Support the `cache_position` input that was added to Hugging Face Whisper
models as part of a revision of how they handle KV-caching. This input is
like `position_ids`, but it has no batch dimension.

See https://github.com/huggingface/optimum/pull/1971 and
https://github.com/huggingface/transformers/pull/31166.
---
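A minimal sketch of the shape difference described above (illustrative only;
it uses plain vectors rather than the crate's tensor types). For a generation
step that covers cache positions `start..start + len`, `position_ids` has a
(batch, sequence) shape while `cache_position` carries the same indices
without the batch dimension:

fn main() {
    // Cache positions covered by the current step (e.g. a 4-token prompt).
    let start = 0usize;
    let len = 4usize;

    // `position_ids`: shape (batch, sequence) - one row per batch entry.
    let position_ids: Vec<Vec<i32>> = vec![(start..start + len).map(|p| p as i32).collect()];

    // `cache_position`: shape (sequence,) - the same indices, no batch dimension.
    let cache_position: Vec<i32> = (start..start + len).map(|p| p as i32).collect();

    assert_eq!(position_ids[0], cache_position);
}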
 rten-generate/src/generator.rs | 42 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 39 insertions(+), 3 deletions(-)

diff --git a/rten-generate/src/generator.rs b/rten-generate/src/generator.rs
index 3396c625..47b1f051 100644
--- a/rten-generate/src/generator.rs
+++ b/rten-generate/src/generator.rs
@@ -131,12 +131,19 @@ pub struct ModelInputsConfig<'a> {
     /// Model input that contains an attention mask.
     pub attention_mask: &'a str,
 
-    /// Model input that contains position IDs for each position.
-    pub position_ids: &'a str,
+    /// Model input that contains KV cache positions for each position.
+    ///
+    /// This input does not have a batch dimension.
+    pub cache_position: &'a str,
 
     /// Patterns for inputs and outputs used for key-value caches.
     pub kv_caches: Vec>,
 
+    /// Model input that contains position IDs for each position.
+    ///
+    /// This input has a batch dimension.
+    pub position_ids: &'a str,
+
     /// Boolean input that is set to false on the first run and true on
     /// subsequent runs.
     pub use_cache_flag: &'a str,
@@ -159,6 +166,7 @@ impl<'a> Default for ModelInputsConfig<'a> {
             input_ids: "input_ids",
             logits: "logits",
             attention_mask: "attention_mask",
+            cache_position: "cache_position",
             position_ids: "position_ids",
             use_cache_flag: "use_cache_branch",
 
@@ -317,6 +325,7 @@ impl<'a> Generator<'a> {
     /// The model may have the optional inputs:
     ///
     /// - `attention_mask` - (batch, sequence) tensor of booleans
+    /// - `cache_position` - (sequence) tensor of KV-cache positions. Usually the same as `position_ids`
     /// - `position_ids` - (batch, sequence) tensor of position indices
     /// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
     ///   where `N` is the layer index
@@ -480,6 +489,15 @@ impl<'a> Generator<'a> {
             });
         }
 
+        let cache_position_input = model.find_node(model_inputs.cache_position);
+        if let Some(cache_position_input) = cache_position_input {
+            generator =
+                generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
+                    NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
+                        .into()
+                });
+        }
+
         let use_cache_input = model.find_node(model_inputs.use_cache_flag);
         if let Some(use_cache_input) = use_cache_input {
             generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
@@ -981,6 +999,7 @@ mod tests {
         // Add inputs and outputs using the standard names.
         let mut inputs = vec![
             NodeInfo::from_name_shape("input_ids", &[]),
+            NodeInfo::from_name_shape("cache_position", &[]),
             NodeInfo::from_name_shape("position_ids", &[]),
             NodeInfo::from_name_shape("attention_mask", &[]),
         ];
@@ -1125,6 +1144,7 @@ mod tests {
         let position_ids = model.find_node("position_ids").unwrap();
         let attention_mask = model.find_node("attention_mask").unwrap();
         let cache_branch = model.find_node("use_cache_branch");
+        let cache_position = model.find_node("cache_position").unwrap();
 
         for step in 0..generation_len {
             let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -1133,6 +1153,9 @@ mod tests {
             let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
             let step_pos_ids: NdTensor = step_pos_ids.try_into().unwrap();
 
+            let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
+            let step_cache_pos: NdTensor = step_cache_pos.try_into().unwrap();
+
             let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
             let step_attn_mask: NdTensor = step_attn_mask.try_into().unwrap();
 
@@ -1155,6 +1178,12 @@ mod tests {
                 assert_eq!(step_pos_ids.size(1), prompt.len());
                 assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
 
+                assert_eq!(step_cache_pos.size(0), prompt.len());
+                assert!(step_cache_pos
+                    .iter()
+                    .map(|x| *x as usize)
+                    .eq(0..prompt.len()));
+
                 if let Some(cache_branch) = cache_branch {
                     assert_eq!(cache_branch.item(), Some(&0));
                 }
@@ -1168,6 +1197,9 @@ mod tests {
                 assert_eq!(step_pos_ids.size(1), 1);
                 assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
 
+                assert_eq!(step_cache_pos.size(0), 1);
+                assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);
+
                 if let Some(cache_branch) = cache_branch {
                     assert_eq!(cache_branch.item(), Some(&1));
                 }
@@ -1194,7 +1226,11 @@ mod tests {
                 (0..prompt.len() + step).map(|x| x as i32).collect();
             assert_eq!(
                 step_pos_ids,
-                NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
+                NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
+            );
+            assert_eq!(
+                step_cache_pos,
+                NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
             );
         }
     }
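A minimal sketch (not part of the patch) of the values the new `cache_position`
varying input produces at each step. It mirrors the `NdTensor::from_fn` closure
above using plain vectors: the prefill step covers cache slots `0..prompt_len`,
and each later decode step covers a single slot, matching the test assertions.

use std::ops::Range;

// Mirror of the closure added in the patch: one i32 per cache position
// covered by the current step, with no batch dimension.
fn cache_position_values(positions: Range<usize>) -> Vec<i32> {
    positions.map(|pos| pos as i32).collect()
}

fn main() {
    // Prefill: a 3-token prompt fills cache slots 0..3.
    assert_eq!(cache_position_values(0..3), vec![0, 1, 2]);
    // Decode: the next token occupies slot 3, so the input has length 1.
    assert_eq!(cache_position_values(3..4), vec![3]);
}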