Skip to content

Commit

Permalink
Merge pull request #394 from robertknight/token-id-filter
Browse files Browse the repository at this point in the history
Add `token_id_filter` function to create logit filter from predicate
  • Loading branch information
robertknight authored Oct 27, 2024
2 parents 77493b5 + 50522b7 commit 63634b8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
65 changes: 63 additions & 2 deletions rten-generate/src/filter.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
//! Filters for processing model outputs prior to sampling.
//!
//! This module defines the [`LogitsFilter`] trait implemented by all filters,
//! plus convenience functions to simplify implementing filters.

use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

use crate::generator::TokenId;

/// Filter which modifies the output logits from a model.
///
/// The filter is applied to the model outputs before a token is sampled.
Expand All @@ -11,6 +17,61 @@ pub trait LogitsFilter {
/// If this method returns `None`, the input logits are passed unmodified
/// to the sampler. `prev_tokens` contains the previously sampled tokens,
/// including the prompt.
fn filter(&self, logits: NdTensorView<f32, 1>, prev_tokens: &[u32])
-> Option<NdTensor<f32, 1>>;
fn filter(
&self,
logits: NdTensorView<f32, 1>,
prev_tokens: &[TokenId],
) -> Option<NdTensor<f32, 1>>;
}

struct TokenIdFilter<F: Fn(TokenId) -> bool> {
predicate: F,
}

impl<F: Fn(TokenId) -> bool> LogitsFilter for TokenIdFilter<F> {
fn filter(
&self,
logits: NdTensorView<f32, 1>,
_prev_tokens: &[TokenId],
) -> Option<NdTensor<f32, 1>> {
Some(NdTensor::from_fn(logits.shape(), |[i]| {
let token_id = i as TokenId;
if (self.predicate)(token_id) {
logits[[i]]
} else {
f32::NEG_INFINITY
}
}))
}
}

/// Create a filter which suppresses all tokens that do not match a predicate by
/// setting the value to `f32::NEG_INFINITY`.
pub fn token_id_filter<F: Fn(TokenId) -> bool>(predicate: F) -> impl LogitsFilter {
TokenIdFilter { predicate }
}

#[cfg(test)]
mod tests {
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

use super::{token_id_filter, LogitsFilter};

#[test]
fn test_token_id_filter() {
let logits = NdTensor::from([0., 1., 2., 3., 4.]);
let filter = token_id_filter(|id| id % 2 == 0);
let output = filter.filter(logits.view(), &[]);
assert_eq!(
output,
Some(NdTensor::from([
0.,
f32::NEG_INFINITY,
2.,
f32::NEG_INFINITY,
4.
]))
);
}
}
8 changes: 4 additions & 4 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,10 @@ pub struct Generator<'a> {
logits_output: NodeId,

/// Filter used to modify logits before sampling.
logits_filter: Option<Box<dyn LogitsFilter>>,
logits_filter: Option<Box<dyn LogitsFilter + 'a>>,

/// Sampler used to get the next token ID from the output logits.
sampler: Box<dyn Sampler>,
sampler: Box<dyn Sampler + 'a>,

/// Previously sampled tokens. These are retained for conditional filtering
/// and sampling.
Expand Down Expand Up @@ -536,13 +536,13 @@ impl<'a> Generator<'a> {

/// Set the filter used to process model output logits before passing them
/// to the sampler to select a token ID.
pub fn with_logits_filter<F: LogitsFilter + 'static>(mut self, filter: F) -> Self {
pub fn with_logits_filter<F: LogitsFilter + 'a>(mut self, filter: F) -> Self {
self.logits_filter = Some(Box::new(filter));
self
}

/// Set the sampler used to sample the next token ID from the output logits.
pub fn with_sampler<S: Sampler + 'static>(mut self, sampler: S) -> Self {
pub fn with_sampler<S: Sampler + 'a>(mut self, sampler: S) -> Self {
self.sampler = Box::new(sampler);
self
}
Expand Down

0 comments on commit 63634b8

Please sign in to comment.