Merge pull request #395 from robertknight/cache-position-input
Support `cache_position` inputs in Hugging Face models
robertknight authored Oct 27, 2024
2 parents 63634b8 + 372d643 commit b304688
Showing 1 changed file with 39 additions and 3 deletions.
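
In short, this change lets the generator populate an optional `cache_position` model input alongside `position_ids`. As the diff below shows, `cache_position` carries the same position indices, but as a 1-D `(sequence)` tensor with no batch dimension. The sketch below (not part of the commit) illustrates the values the generator would feed on each run, assuming the prefill step covers the whole prompt and every later decode step covers a single new position; `cache_position_values` is a hypothetical helper mirroring the closure added in the diff.

    // Sketch only: how cache positions follow from the token positions
    // processed in a generation step. Not code from this commit.
    use std::ops::Range;

    fn cache_position_values(positions: Range<usize>) -> Vec<i32> {
        // One entry per position processed in this run, no batch dimension.
        positions.map(|pos| pos as i32).collect()
    }

    fn main() {
        let prompt_len = 4;

        // Prefill step: one entry per prompt token.
        assert_eq!(cache_position_values(0..prompt_len), vec![0, 1, 2, 3]);

        // First decode step: a single entry for the next position.
        assert_eq!(cache_position_values(prompt_len..prompt_len + 1), vec![4]);
    }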
42 changes: 39 additions & 3 deletions rten-generate/src/generator.rs
@@ -131,12 +131,19 @@ pub struct ModelInputsConfig<'a> {
/// Model input that contains an attention mask.
pub attention_mask: &'a str,

/// Model input that contains position IDs for each position.
pub position_ids: &'a str,
/// Model input that contains KV cache positions for each position.
///
/// This input does not have a batch dimension.
pub cache_position: &'a str,

/// Patterns for inputs and outputs used for key-value caches.
pub kv_caches: Vec<KVCachePair<'a>>,

/// Model input that contains position IDs for each position.
///
/// This input has a batch dimension.
pub position_ids: &'a str,

/// Boolean input that is set to false on the first run and true on
/// subsequent runs.
pub use_cache_flag: &'a str,
@@ -159,6 +166,7 @@ impl<'a> Default for ModelInputsConfig<'a> {
input_ids: "input_ids",
logits: "logits",
attention_mask: "attention_mask",
cache_position: "cache_position",
position_ids: "position_ids",
use_cache_flag: "use_cache_branch",

@@ -317,6 +325,7 @@ impl<'a> Generator<'a> {
/// The model may have the optional inputs:
///
/// - `attention_mask` - (batch, sequence) tensor of booleans
/// - `cache_position` - (sequence) tensor of KV-cache positions. Usually the same as `position_ids`
/// - `position_ids` - (batch, sequence) tensor of position indices
/// - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
/// where `N` is the layer index
@@ -480,6 +489,15 @@ impl<'a> Generator<'a> {
});
}

let cache_position_input = model.find_node(model_inputs.cache_position);
if let Some(cache_position_input) = cache_position_input {
generator =
generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
.into()
});
}

let use_cache_input = model.find_node(model_inputs.use_cache_flag);
if let Some(use_cache_input) = use_cache_input {
generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
@@ -981,6 +999,7 @@ mod tests {
// Add inputs and outputs using the standard names.
let mut inputs = vec![
NodeInfo::from_name_shape("input_ids", &[]),
NodeInfo::from_name_shape("cache_position", &[]),
NodeInfo::from_name_shape("position_ids", &[]),
NodeInfo::from_name_shape("attention_mask", &[]),
];
@@ -1125,6 +1144,7 @@ mod tests {
let position_ids = model.find_node("position_ids").unwrap();
let attention_mask = model.find_node("attention_mask").unwrap();
let cache_branch = model.find_node("use_cache_branch");
let cache_position = model.find_node("cache_position").unwrap();

for step in 0..generation_len {
let step_inputs = model.get_inputs(step, input_id).unwrap();
@@ -1133,6 +1153,9 @@
let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();

let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();

let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();

@@ -1155,6 +1178,12 @@
assert_eq!(step_pos_ids.size(1), prompt.len());
assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));

assert_eq!(step_cache_pos.size(0), prompt.len());
assert!(step_cache_pos
.iter()
.map(|x| *x as usize)
.eq(0..prompt.len()));

if let Some(cache_branch) = cache_branch {
assert_eq!(cache_branch.item(), Some(&0));
}
@@ -1168,6 +1197,9 @@
assert_eq!(step_pos_ids.size(1), 1);
assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);

assert_eq!(step_cache_pos.size(0), 1);
assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);

if let Some(cache_branch) = cache_branch {
assert_eq!(cache_branch.item(), Some(&1));
}
@@ -1194,7 +1226,11 @@ mod tests {
(0..prompt.len() + step).map(|x| x as i32).collect();
assert_eq!(
step_pos_ids,
NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids)
NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
);
assert_eq!(
step_cache_pos,
NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
);
}
}
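
The assertions in the tests above pin down the shape difference between the two inputs: `position_ids` keeps a `(batch, sequence)` shape while `cache_position` is `(sequence)` with no batch dimension. A minimal sketch of the two tensors on a single decode step, assuming the rten-tensor crate and its prelude are available as in the file under change (variable names are illustrative):

    // Sketch only, not code from this commit.
    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    fn main() {
        // Position index of the next token on a decode step.
        let next_pos = 7i32;

        // `position_ids` is (batch, sequence): a batch of one, one position.
        let position_ids: NdTensor<i32, 2> = NdTensor::from_data([1, 1], vec![next_pos]);

        // `cache_position` is (sequence): the same value, no batch dimension.
        let cache_position: NdTensor<i32, 1> = NdTensor::from_data([1], vec![next_pos]);

        assert_eq!(position_ids.size(0), 1);
        assert_eq!(position_ids.size(1), 1);
        assert_eq!(cache_position.size(0), 1);
    }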