Skip to content

Commit

Permalink
Support samplers and logit filters with non-static lifetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Oct 27, 2024
1 parent 77493b5 commit ed9fa13
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,10 @@ pub struct Generator<'a> {
logits_output: NodeId,

/// Filter used to modify logits before sampling.
logits_filter: Option<Box<dyn LogitsFilter>>,
logits_filter: Option<Box<dyn LogitsFilter + 'a>>,

/// Sampler used to get the next token ID from the output logits.
sampler: Box<dyn Sampler>,
sampler: Box<dyn Sampler + 'a>,

/// Previously sampled tokens. These are retained for conditional filtering
/// and sampling.
Expand Down Expand Up @@ -536,13 +536,13 @@ impl<'a> Generator<'a> {

/// Set the filter used to process model output logits before passing them
/// to the sampler to select a token ID.
pub fn with_logits_filter<F: LogitsFilter + 'static>(mut self, filter: F) -> Self {
pub fn with_logits_filter<F: LogitsFilter + 'a>(mut self, filter: F) -> Self {
self.logits_filter = Some(Box::new(filter));
self
}

/// Set the sampler used to sample the next token ID from the output logits.
pub fn with_sampler<S: Sampler + 'static>(mut self, sampler: S) -> Self {
pub fn with_sampler<S: Sampler + 'a>(mut self, sampler: S) -> Self {
self.sampler = Box::new(sampler);
self
}
Expand Down

0 comments on commit ed9fa13

Please sign in to comment.