diff --git a/rten-generate/src/generator.rs b/rten-generate/src/generator.rs
index 3396c625..47b1f051 100644
--- a/rten-generate/src/generator.rs
+++ b/rten-generate/src/generator.rs
@@ -131,12 +131,19 @@ pub struct ModelInputsConfig<'a> {
     /// Model input that contains an attention mask.
     pub attention_mask: &'a str,
 
-    /// Model input that contains position IDs for each position.
-    pub position_ids: &'a str,
+    /// Model input that contains KV cache positions for each position.
+    ///
+    /// This input does not have a batch dimension.
+    pub cache_position: &'a str,
 
     /// Patterns for inputs and outputs used for key-value caches.
     pub kv_caches: Vec<KVCachePattern<'a>>,
 
+    /// Model input that contains position IDs for each position.
+    ///
+    /// This input has a batch dimension.
+    pub position_ids: &'a str,
+
     /// Boolean input that is set to false on the first run and true on
     /// subsequent runs.
     pub use_cache_flag: &'a str,
@@ -159,6 +166,7 @@ impl<'a> Default for ModelInputsConfig<'a> {
             input_ids: "input_ids",
             logits: "logits",
             attention_mask: "attention_mask",
+            cache_position: "cache_position",
             position_ids: "position_ids",
             use_cache_flag: "use_cache_branch",
 
@@ -317,6 +325,7 @@ impl<'a> Generator<'a> {
     /// The model may have the optional inputs:
     ///
     /// - `attention_mask` - (batch, sequence) tensor of booleans
+    /// - `cache_position` - (sequence) tensor of KV-cache positions. Usually the same as `position_ids`
     /// - `position_ids` - (batch, sequence) tensor of position indices
     /// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
     ///   where `N` is the layer index
@@ -480,6 +489,15 @@ impl<'a> Generator<'a> {
             });
         }
 
+        let cache_position_input = model.find_node(model_inputs.cache_position);
+        if let Some(cache_position_input) = cache_position_input {
+            generator =
+                generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
+                    NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
+                        .into()
+                });
+        }
+
         let use_cache_input = model.find_node(model_inputs.use_cache_flag);
         if let Some(use_cache_input) = use_cache_input {
             generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
@@ -981,6 +999,7 @@ mod tests {
         // Add inputs and outputs using the standard names.
         let mut inputs = vec![
             NodeInfo::from_name_shape("input_ids", &[]),
+            NodeInfo::from_name_shape("cache_position", &[]),
             NodeInfo::from_name_shape("position_ids", &[]),
             NodeInfo::from_name_shape("attention_mask", &[]),
         ];
@@ -1125,6 +1144,7 @@ mod tests {
        let position_ids = model.find_node("position_ids").unwrap();
        let attention_mask = model.find_node("attention_mask").unwrap();
        let cache_branch = model.find_node("use_cache_branch");
+       let cache_position = model.find_node("cache_position").unwrap();
 
        for step in 0..generation_len {
            let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -1133,6 +1153,9 @@ mod tests {
            let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
            let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();
 
+           let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
+           let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();
+
            let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
            let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();
 
@@ -1155,6 +1178,12 @@ mod tests {
                assert_eq!(step_pos_ids.size(1), prompt.len());
                assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
 
+               assert_eq!(step_cache_pos.size(0), prompt.len());
+               assert!(step_cache_pos
+                   .iter()
+                   .map(|x| *x as usize)
+                   .eq(0..prompt.len()));
+
                if let Some(cache_branch) = cache_branch {
                    assert_eq!(cache_branch.item(), Some(&0));
                }
@@ -1168,6 +1197,9 @@ mod tests {
                assert_eq!(step_pos_ids.size(1), 1);
                assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
 
+               assert_eq!(step_cache_pos.size(0), 1);
+               assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);
+
                if let Some(cache_branch) = cache_branch {
                    assert_eq!(cache_branch.item(), Some(&1));
                }
@@ -1194,7 +1226,11 @@ mod tests {
                (0..prompt.len() + step).map(|x| x as i32).collect();
            assert_eq!(
                step_pos_ids,
-               NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
+               NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
            );
+           assert_eq!(
+               step_cache_pos,
+               NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
+           );
        }
    }
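
Note (illustration, not part of the patch): the two position inputs added above differ only in shape. `cache_position` is a 1-D `(sequence)` tensor, while `position_ids` keeps a batch dimension, `(batch, sequence)`. A minimal standalone sketch of how the values are computed from the per-step token range, assuming the `rten_tensor::NdTensor::from_fn` API the patch itself uses; the two helper functions and `main` are hypothetical, not crate APIs:

    // Hypothetical standalone example; `cache_position_input` and
    // `position_ids_input` are illustrative helpers, not rten-generate APIs.
    use std::ops::Range;

    use rten_tensor::NdTensor;

    /// Build the `cache_position` input: 1-D, no batch dimension.
    /// `positions` is the token range for the current step, as passed to the
    /// `with_varying_input` callback in the patch.
    fn cache_position_input(positions: Range<usize>) -> NdTensor<i32, 1> {
        NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
    }

    /// Build the `position_ids` input: 2-D, with a batch dimension. Each batch
    /// row holds the same position indices.
    fn position_ids_input(batch_size: usize, positions: Range<usize>) -> NdTensor<i32, 2> {
        NdTensor::from_fn([batch_size, positions.len()], |[_batch, pos]| {
            (positions.start + pos) as i32
        })
    }

    fn main() {
        // First run: the whole 4-token prompt, positions 0..4.
        assert_eq!(cache_position_input(0..4).to_vec(), &[0, 1, 2, 3]);
        // A later decode step: a single new token at position 4.
        assert_eq!(position_ids_input(1, 4..5).to_vec(), &[4]);
    }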