Skip to content

Commit

Permalink
Merge pull request #393 from robertknight/logits-filter
Browse files Browse the repository at this point in the history
Add API to rten-generate for filtering/processing logits
  • Loading branch information
robertknight authored Oct 26, 2024
2 parents 72e4c7a + ab2b294 commit 77493b5
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 17 deletions.
16 changes: 16 additions & 0 deletions rten-generate/src/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//! Filters for processing model outputs prior to sampling.

use rten_tensor::{NdTensor, NdTensorView};

/// Filter which modifies the output logits from a model.
///
/// The filter is applied to the model outputs before a token is sampled.
/// A `Generator` holds at most one filter, installed with
/// `Generator::with_logits_filter`, and invokes it on every generation step
/// after the model has run and before the sampler is called.
///
/// `filter` takes `&self`, so an implementation which needs to record state
/// between calls has to use interior mutability (eg. `Cell` or `RefCell`).
pub trait LogitsFilter {
/// Filter the model's output and return the modified logits.
///
/// `logits` is a 1D view of the model's output for the last position of
/// the current input sequence. The sampler selects the next token ID from
/// these values.
///
/// If this method returns `None`, the input logits are passed unmodified
/// to the sampler. If it returns `Some(tensor)`, `tensor` is passed to the
/// sampler in place of `logits`.
///
/// `prev_tokens` contains the previously sampled tokens, including the
/// prompt. The prompt tokens come first, followed by each sampled token in
/// generation order. The token being sampled in the current step is not
/// included.
fn filter(&self, logits: NdTensorView<f32, 1>, prev_tokens: &[u32])
-> Option<NdTensor<f32, 1>>;
}
142 changes: 125 additions & 17 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rten_tensor::{NdTensor, Tensor};
#[cfg(feature = "text-decoder")]
use rten_text::tokenizers::{Tokenizer, TokenizerError};

use crate::filter::LogitsFilter;
use crate::metrics::Metrics;
use crate::model::Model;
use crate::sampler::{ArgMaxSampler, Sampler};
Expand Down Expand Up @@ -216,7 +217,26 @@ impl<'a> Default for ModelInputsConfig<'a> {
/// such as stopping generation when an end-of-text token is reached. You can
/// also use all of the standard iterator adapters. For example
/// `generator.take(30)` will return an iterator that stops generation after 30
/// tokens have been produced).
/// tokens have been produced.
///
/// ## Processing pipeline
///
/// Each call to [`next`](Iterator::next) performs the following steps:
///
/// 1. Compute model inputs. This includes the token IDs (prompt on first run,
/// most recently sampled token after that), constant inputs and
/// position-varying inputs (e.g. attention mask).
/// 2. Run the model
/// 3. Apply filters to model outputs ("logits")
/// 4. Sample a token from the logits
/// 5. Save the sampled token and KV-caches from the model for the next
/// generation step.
/// 6. Return the sampled token as the iterator output
///
/// ## Logit filters
///
/// The raw model outputs can be modified before sampling by configuring a
/// [`LogitsFilter`] using [`with_logits_filter`](Generator::with_logits_filter).
///
/// ## Sampling
///
Expand Down Expand Up @@ -254,17 +274,24 @@ pub struct Generator<'a> {
/// Input token IDs for the next run of the model.
input_ids: Vec<TokenId>,

// Input node IDs
/// Position ID associated with the first token in `input_ids`.
input_offset: usize,

/// ID of the model input node which receives the token IDs in `input_ids`.
input_ids_input: NodeId,

// Output node IDs
/// Output node IDs
logits_output: NodeId,

// Sampler used to get the next token ID from the output logits.
/// Filter used to modify logits before sampling.
logits_filter: Option<Box<dyn LogitsFilter>>,

/// Sampler used to get the next token ID from the output logits.
sampler: Box<dyn Sampler>,

/// Length of the sequence generated so far.
seq_len: u32,
/// Previously sampled tokens. These are retained for conditional filtering
/// and sampling.
prev_tokens: Vec<u32>,

/// Self-attention key-value cache. This is extended on each iteration.
kv_cache: Vec<KvCache>,
Expand Down Expand Up @@ -423,12 +450,14 @@ impl<'a> Generator<'a> {
// constant inputs are added.
constant_prop_inputs: Some(Vec::new()),

logits_filter: None,
input_ids: vec![],
input_ids_input,
input_offset: 0,
logits_output,
kv_cache,
encoder_kv_cache,
seq_len: 0,
prev_tokens: Vec::new(),
sampler: Box::new(ArgMaxSampler {}),
};

Expand Down Expand Up @@ -505,6 +534,13 @@ impl<'a> Generator<'a> {
self
}

/// Install `filter` as the step which processes the model's output logits
/// before they are handed to the sampler to choose a token ID.
///
/// Any filter configured by an earlier call is replaced. The generator is
/// returned so that configuration calls can be chained.
pub fn with_logits_filter<F>(mut self, filter: F) -> Self
where
    F: LogitsFilter + 'static,
{
    let boxed: Box<dyn LogitsFilter> = Box::new(filter);
    self.logits_filter = Some(boxed);
    self
}

/// Set the sampler used to sample the next token ID from the output logits.
pub fn with_sampler<S: Sampler + 'static>(mut self, sampler: S) -> Self {
self.sampler = Box::new(sampler);
Expand Down Expand Up @@ -534,7 +570,7 @@ impl<'a> Generator<'a> {
.collect::<Tensor<_>>()
.into_shape([batch_size, self.input_ids.len()]);

let seq_range = (self.seq_len as usize)..(self.seq_len as usize + self.input_ids.len());
let input_positions = self.input_offset..self.input_offset + self.input_ids.len();

let mut model_inputs: Vec<(NodeId, InputOrOutput)> =
vec![(self.input_ids_input, input_ids.view().into())];
Expand Down Expand Up @@ -563,11 +599,9 @@ impl<'a> Generator<'a> {
}

if !self.varying_inputs.is_empty() {
model_inputs.extend(
self.varying_inputs
.iter()
.map(|(node_id, value_fn)| (*node_id, value_fn(batch_size, seq_range.clone()))),
);
model_inputs.extend(self.varying_inputs.iter().map(|(node_id, value_fn)| {
(*node_id, value_fn(batch_size, input_positions.clone()))
}));
}

// Add key-value cache from previous run. The model takes ownership
Expand Down Expand Up @@ -611,9 +645,21 @@ impl<'a> Generator<'a> {
.run(model_inputs, &model_outputs, self.run_options.clone())
.map_err(wrap_error)?;

// Sample output token.
// Apply filtering to model outputs.
if self.prev_tokens.is_empty() {
self.prev_tokens.extend(self.input_ids.iter());
}
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
let next_id = self.sampler.sample(logits.slice((0, -1)));
let last_logits = logits.slice((0, -1));
let filtered_logits = self
.logits_filter
.as_ref()
.and_then(|f| f.filter(last_logits, &self.prev_tokens))
.map(|l| l.into_cow())
.unwrap_or(last_logits.as_cow());

// Sample output token.
let next_id = self.sampler.sample(filtered_logits.view());

// Update the self-attention key-value cache.
//
Expand Down Expand Up @@ -654,8 +700,9 @@ impl<'a> Generator<'a> {
}

// Update the token IDs and sequence offset for the next iteration.
self.prev_tokens.push(next_id);
if !self.kv_cache.is_empty() {
self.seq_len += self.input_ids.len() as u32;
self.input_offset += self.input_ids.len();
self.input_ids = vec![next_id];
} else {
self.input_ids.push(next_id);
Expand Down Expand Up @@ -733,12 +780,14 @@ mod tests {
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::error::Error;
use std::rc::Rc;

use rten::{Dimension, InputOrOutput, NodeId, Output, RunOptions};
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;
use rten_tensor::{NdTensor, NdTensorView};

use super::{Generator, GeneratorUtils};
use crate::filter::LogitsFilter;
use crate::metrics::Metrics;
use crate::model::{Model, NodeInfo};

Expand Down Expand Up @@ -1263,6 +1312,65 @@ mod tests {
Ok(())
}

#[test]
fn test_filter() -> Result<(), Box<dyn Error>> {
    /// Test filter which rewrites the logits as a one-hot vector whose hot
    /// position is twice the index of the largest input logit. It also
    /// records the token history it was most recently called with.
    struct DoubleIndexFilter {
        seen_tokens: Rc<RefCell<Vec<u32>>>,
    }

    impl LogitsFilter for DoubleIndexFilter {
        fn filter(
            &self,
            logits: NdTensorView<f32, 1>,
            prev_tokens: &[u32],
        ) -> Option<NdTensor<f32, 1>> {
            // Keep a copy of the history so the test can inspect what the
            // generator passed on the final call.
            *self.seen_tokens.borrow_mut() = prev_tokens.to_vec();

            // Returns `None` (ie. leaves the logits untouched) if there are
            // no logits to pick a maximum from.
            let (argmax, _) = logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))?;
            let hot_index = argmax * 2;

            let one_hot =
                NdTensor::from_fn(
                    logits.shape(),
                    |[idx]| if idx == hot_index { 1. } else { 0. },
                );
            Some(one_hot)
        }
    }

    // The vocabulary must be more than twice the largest ID in
    // `expected_token_ids`, since the filter doubles the selected index.
    let params = TransformerParams {
        n_vocab: 8,
        ..TransformerParams::default()
    };

    let expected_token_ids = [0, 1, 2, 3];
    let prompt = [5, 6, 7];
    let model = fake_transformer_model(
        params,
        Some(KvCacheType::Decoder),
        prompt.len(),
        &expected_token_ids,
    );

    let history = Rc::new(RefCell::new(Vec::new()));
    let filter = DoubleIndexFilter {
        seen_tokens: Rc::clone(&history),
    };

    let sampled: Vec<_> = Generator::from_model(&model)?
        .with_prompt(&prompt)
        .with_logits_filter(filter)
        .take(expected_token_ids.len())
        .map(|id| id.expect("generation failed"))
        .collect();

    // Every token the model would have produced is doubled by the filter.
    assert_eq!(sampled, [0, 2, 4, 6]);
    // On the last step the filter saw the prompt plus the three tokens
    // sampled before it.
    assert_eq!(history.borrow().as_slice(), [5, 6, 7, 0, 2, 4]);

    Ok(())
}

#[test]
fn test_run_options() -> Result<(), Box<dyn Error>> {
let params = TransformerParams::default();
Expand Down
1 change: 1 addition & 0 deletions rten-generate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//! [rten]: https://github.com/robertknight/rten
//! [rten-examples]: https://github.com/robertknight/rten/tree/main/rten-examples

pub mod filter;
pub mod generator;
pub mod metrics;
pub mod model;
Expand Down

0 comments on commit 77493b5

Please sign in to comment.