Add API to rten-generate for filtering/processing logits #393

Merged: 1 commit, Oct 26, 2024
16 changes: 16 additions & 0 deletions rten-generate/src/filter.rs
@@ -0,0 +1,16 @@
//! Filters for processing model outputs prior to sampling.

use rten_tensor::{NdTensor, NdTensorView};

/// Filter which modifies the output logits from a model.
///
/// The filter is applied to the model outputs before a token is sampled.
pub trait LogitsFilter {
/// Filter the model's output and return the modified logits.
///
/// If this method returns `None`, the input logits are passed unmodified
/// to the sampler. `prev_tokens` contains the previously sampled tokens,
/// including the prompt.
fn filter(&self, logits: NdTensorView<f32, 1>, prev_tokens: &[u32])
-> Option<NdTensor<f32, 1>>;
}
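
For context, the trait is small enough to sketch an implementation. Below is a minimal, hypothetical filter that masks out a fixed set of banned token IDs by forcing their logits to negative infinity; the `BanTokensFilter` name and its `banned` field are illustrative and not part of this change:

```rust
use rten_generate::filter::LogitsFilter;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

/// Hypothetical filter that prevents a fixed set of tokens from being
/// sampled by setting their logits to negative infinity.
struct BanTokensFilter {
    banned: Vec<u32>,
}

impl LogitsFilter for BanTokensFilter {
    fn filter(
        &self,
        logits: NdTensorView<f32, 1>,
        _prev_tokens: &[u32],
    ) -> Option<NdTensor<f32, 1>> {
        if self.banned.is_empty() {
            // Returning `None` passes the logits through unmodified.
            return None;
        }
        let mut filtered = logits.to_tensor();
        for &token_id in &self.banned {
            if let Some(logit) = filtered.get_mut([token_id as usize]) {
                *logit = f32::NEG_INFINITY;
            }
        }
        Some(filtered)
    }
}
```

Returning `None` in the no-op case avoids allocating a new tensor when nothing needs to change, which matches the intent of the `Option` return type documented above.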
142 changes: 125 additions & 17 deletions rten-generate/src/generator.rs
@@ -11,6 +11,7 @@ use rten_tensor::{NdTensor, Tensor};
#[cfg(feature = "text-decoder")]
use rten_text::tokenizers::{Tokenizer, TokenizerError};

use crate::filter::LogitsFilter;
use crate::metrics::Metrics;
use crate::model::Model;
use crate::sampler::{ArgMaxSampler, Sampler};
@@ -216,7 +217,26 @@ impl<'a> Default for ModelInputsConfig<'a> {
/// such as stopping generation when an end-of-text token is reached. You can
/// also use all of the standard iterator adapters. For example
/// `generator.take(30)` will return an iterator that stops generation after 30
/// tokens have been produced).
/// tokens have been produced.
///
/// ## Processing pipeline
///
/// Each call to [`next`](Iterator::next) performs the following steps:
///
/// 1. Compute model inputs. This includes the token IDs (prompt on first run,
/// most recently sampled token after that), constant inputs and
/// position-varying inputs (eg. attention mask).
/// 2. Run the model
/// 3. Apply filters to model outputs ("logits")
/// 4. Sample a token from the logits
/// 5. Save the sampled token and KV-caches from the model for the next
/// generation step.
/// 6. Return the sampled token as the iterator output
///
/// ## Logit filters
///
/// The raw model outputs can be modified before sampling by configuring a
/// [`LogitsFilter`] using [`with_logits_filter`](Generator::with_logits_filter).
///
/// ## Sampling
///
@@ -254,17 +274,24 @@ pub struct Generator<'a> {
/// Input token IDs for the next run of the model.
input_ids: Vec<TokenId>,

// Input node IDs
/// Position ID associated with the first token in `input_ids`.
input_offset: usize,

/// Input node IDs
input_ids_input: NodeId,

// Output node IDs
/// Output node IDs
logits_output: NodeId,

// Sampler used to get the next token ID from the output logits.
/// Filter used to modify logits before sampling.
logits_filter: Option<Box<dyn LogitsFilter>>,

/// Sampler used to get the next token ID from the output logits.
sampler: Box<dyn Sampler>,

/// Length of the sequence generated so far.
seq_len: u32,
/// Previously sampled tokens. These are retained for conditional filtering
/// and sampling.
prev_tokens: Vec<u32>,

/// Self-attention key-value cache. This is extended on each iteration.
kv_cache: Vec<KvCache>,
@@ -423,12 +450,14 @@ impl<'a> Generator<'a> {
// constant inputs are added.
constant_prop_inputs: Some(Vec::new()),

logits_filter: None,
input_ids: vec![],
input_ids_input,
input_offset: 0,
logits_output,
kv_cache,
encoder_kv_cache,
seq_len: 0,
prev_tokens: Vec::new(),
sampler: Box::new(ArgMaxSampler {}),
};

@@ -505,6 +534,13 @@ impl<'a> Generator<'a> {
self
}

/// Set the filter used to process model output logits before passing them
/// to the sampler to select a token ID.
pub fn with_logits_filter<F: LogitsFilter + 'static>(mut self, filter: F) -> Self {
self.logits_filter = Some(Box::new(filter));
self
}

/// Set the sampler used to sample the next token ID from the output logits.
pub fn with_sampler<S: Sampler + 'static>(mut self, sampler: S) -> Self {
self.sampler = Box::new(sampler);
@@ -534,7 +570,7 @@ impl<'a> Generator<'a> {
.collect::<Tensor<_>>()
.into_shape([batch_size, self.input_ids.len()]);

let seq_range = (self.seq_len as usize)..(self.seq_len as usize + self.input_ids.len());
let input_positions = self.input_offset..self.input_offset + self.input_ids.len();

let mut model_inputs: Vec<(NodeId, InputOrOutput)> =
vec![(self.input_ids_input, input_ids.view().into())];
@@ -563,11 +599,9 @@ impl<'a> Generator<'a> {
}

if !self.varying_inputs.is_empty() {
model_inputs.extend(
self.varying_inputs
.iter()
.map(|(node_id, value_fn)| (*node_id, value_fn(batch_size, seq_range.clone()))),
);
model_inputs.extend(self.varying_inputs.iter().map(|(node_id, value_fn)| {
(*node_id, value_fn(batch_size, input_positions.clone()))
}));
}

// Add key-value cache from previous run. The model takes ownership
@@ -611,9 +645,21 @@ impl<'a> Generator<'a> {
.run(model_inputs, &model_outputs, self.run_options.clone())
.map_err(wrap_error)?;

// Sample output token.
// Apply filtering to model outputs.
if self.prev_tokens.is_empty() {
self.prev_tokens.extend(self.input_ids.iter());
}
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
let next_id = self.sampler.sample(logits.slice((0, -1)));
let last_logits = logits.slice((0, -1));
let filtered_logits = self
.logits_filter
.as_ref()
.and_then(|f| f.filter(last_logits, &self.prev_tokens))
.map(|l| l.into_cow())
.unwrap_or(last_logits.as_cow());

// Sample output token.
let next_id = self.sampler.sample(filtered_logits.view());

// Update the self-attention key-value cache.
//
@@ -654,8 +700,9 @@ impl<'a> Generator<'a> {
}

// Update the token IDs and sequence offset for the next iteration.
self.prev_tokens.push(next_id);
if !self.kv_cache.is_empty() {
self.seq_len += self.input_ids.len() as u32;
self.input_offset += self.input_ids.len();
self.input_ids = vec![next_id];
} else {
self.input_ids.push(next_id);
Expand Down Expand Up @@ -733,12 +780,14 @@ mod tests {
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::error::Error;
use std::rc::Rc;

use rten::{Dimension, InputOrOutput, NodeId, Output, RunOptions};
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;
use rten_tensor::{NdTensor, NdTensorView};

use super::{Generator, GeneratorUtils};
use crate::filter::LogitsFilter;
use crate::metrics::Metrics;
use crate::model::{Model, NodeInfo};

@@ -1263,6 +1312,65 @@
Ok(())
}

#[test]
fn test_filter() -> Result<(), Box<dyn Error>> {
let mut params = TransformerParams::default();
params.n_vocab = 8; // Must be >2x the max token ID in `expected_token_ids`.

let expected_token_ids = [0, 1, 2, 3];
let prompt = [5, 6, 7];
let model = fake_transformer_model(
params,
Some(KvCacheType::Decoder),
prompt.len(),
&expected_token_ids,
);

let generator = Generator::from_model(&model)?;

// Filter that modifies logits to double the selected token ID.
struct DoubleIndexFilter {
prev_tokens: Rc<RefCell<Vec<u32>>>,
}
impl LogitsFilter for DoubleIndexFilter {
fn filter(
&self,
logits: NdTensorView<f32, 1>,
prev_tokens: &[u32],
) -> Option<NdTensor<f32, 1>> {
self.prev_tokens.replace(prev_tokens.to_vec());

let max_idx = logits
.iter()
.enumerate()
.max_by(|(_i, x), (_j, y)| x.total_cmp(y))
.map(|(i, _x)| i)?;
Some(NdTensor::from_fn(logits.shape(), |[i]| {
if i == max_idx * 2 {
1.
} else {
0.
}
}))
}
}

let prev_tokens = Rc::new(RefCell::new(Vec::new()));
let output_token_ids: Vec<_> = generator
.with_prompt(&prompt)
.with_logits_filter(DoubleIndexFilter {
prev_tokens: prev_tokens.clone(),
})
.take(expected_token_ids.len())
.map(|id| id.expect("generation failed"))
.collect();

assert_eq!(output_token_ids, [0, 2, 4, 6]);
assert_eq!(prev_tokens.borrow().as_slice(), [5, 6, 7, 0, 2, 4]);

Ok(())
}

#[test]
fn test_run_options() -> Result<(), Box<dyn Error>> {
let params = TransformerParams::default();
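
Taken together with the builder methods above, wiring a filter into generation is a one-line change at the call site. A rough usage sketch, assuming the hypothetical `BanTokensFilter` from the earlier sketch and that `GeneratorError` converts into the caller's error type; this is not code from the PR:

```rust
use rten_generate::generator::Generator;
use rten_generate::model::Model;

fn generate_tokens(
    model: &dyn Model,
    prompt: &[u32],
) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
    let tokens = Generator::from_model(model)?
        .with_prompt(prompt)
        // New in this PR: modify logits before the sampler sees them.
        .with_logits_filter(BanTokensFilter { banned: vec![0] })
        .take(30)
        .collect::<Result<Vec<_>, _>>()?;
    Ok(tokens)
}
```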
1 change: 1 addition & 0 deletions rten-generate/src/lib.rs
@@ -7,6 +7,7 @@
//! [rten]: https://github.com/robertknight/rten
//! [rten-examples]: https://github.com/robertknight/rten/tree/main/rten-examples

pub mod filter;
pub mod generator;
pub mod metrics;
pub mod model;
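
As a closing illustration of why `filter` receives `prev_tokens` (the prompt plus all previously sampled tokens), conditional filters such as a repetition penalty become straightforward to express. A sketch assuming a CTRL-style penalty scheme; the `RepetitionPenalty` type is illustrative, not part of this change:

```rust
use rten_generate::filter::LogitsFilter;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

/// Hypothetical filter that damps the logits of tokens which have already
/// appeared, discouraging the model from repeating itself.
struct RepetitionPenalty {
    /// Penalty factor greater than 1.0; larger values discourage
    /// repetition more strongly.
    penalty: f32,
}

impl LogitsFilter for RepetitionPenalty {
    fn filter(
        &self,
        logits: NdTensorView<f32, 1>,
        prev_tokens: &[u32],
    ) -> Option<NdTensor<f32, 1>> {
        let mut filtered = logits.to_tensor();
        for &token_id in prev_tokens {
            if let Some(logit) = filtered.get_mut([token_id as usize]) {
                // Divide positive logits and multiply negative ones so the
                // adjustment always lowers the token's probability.
                *logit = if *logit > 0. {
                    *logit / self.penalty
                } else {
                    *logit * self.penalty
                };
            }
        }
        Some(filtered)
    }
}
```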