From 132cb4172b60395a4f3e163e8cf752a7e559e896 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Mon, 25 Nov 2024 08:41:30 -0500
Subject: [PATCH] Add prompt batching support

---
 mistralrs-core/src/pipeline/mod.rs      | 62 +++++++++++++++----------
 mistralrs-core/src/pipeline/response.rs |  8 ++--
 mistralrs-core/src/response.rs          | 18 +++++--
 mistralrs-core/src/sequence.rs          | 15 ++++--
 mistralrs/examples/perplexity/main.rs   | 34 +++++++++-----
 mistralrs/src/model.rs                  | 11 +++--
 6 files changed, 96 insertions(+), 52 deletions(-)

diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs
index 880232811..e1e09e405 100644
--- a/mistralrs-core/src/pipeline/mod.rs
+++ b/mistralrs-core/src/pipeline/mod.rs
@@ -327,9 +327,20 @@ pub trait Pipeline:
         );
 
         let mut logits = vec![None; input_seqs.len()];
+        let prompt_batchsize = self
+            .get_metadata()
+            .prompt_batchsize
+            .map(NonZeroUsize::get)
+            .unwrap_or(1);
+        let len_inputs = input_seqs
+            .iter()
+            .map(|seq| (seq.get_toks().len() + prompt_batchsize - 1) / prompt_batchsize)
+            .max()
+            .unwrap();
+        let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
         let mut exec_duration = Duration::ZERO;
-        for (i, inputs) in inputs_iter.enumerate() {
+        for (i, inputs) in inputs_iter.into_iter().enumerate() {
             let InputProcessorOutput {
                 inputs,
                 seq_indices,
@@ -399,6 +410,9 @@
             exec_duration += end.duration_since(start);
 
             for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
+                if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
+                    raw_out_logits[seq_idx][i] = Some(logits.i(logit_idx)?);
+                }
                 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
             }
         }
@@ -432,17 +446,10 @@
                 ForwardInputsResult::RawLogits { .. } => {
                     response::send_raw_responses(
                         input_seqs,
-                        logits
-                            .iter()
-                            .map(|r| {
-                                #[allow(irrefutable_let_patterns)]
-                                let ForwardInputsResult::RawLogits { logits } = r
-                                else {
-                                    unreachable!("All results must have same type")
-                                };
-                                logits.clone()
-                            })
-                            .collect::<Vec<_>>(),
+                        raw_out_logits
+                            .into_iter()
+                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
+                            .collect(),
                     )
                     .await?;
                 }
@@ -523,9 +530,20 @@
         );
 
         let mut logits = vec![None; input_seqs.len()];
+        let prompt_batchsize = self
+            .get_metadata()
+            .prompt_batchsize
+            .map(NonZeroUsize::get)
+            .unwrap_or(1);
+        let len_inputs = input_seqs
+            .iter()
+            .map(|seq| (seq.get_toks().len() + prompt_batchsize - 1) / prompt_batchsize)
+            .max()
+            .unwrap();
+        let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
         let mut exec_duration = Duration::ZERO;
-        for inputs in inputs_iter {
+        for (i, inputs) in inputs_iter.into_iter().enumerate() {
             let InputProcessorOutput {
                 inputs,
                 seq_indices,
@@ -537,6 +555,9 @@
             exec_duration += end.duration_since(start);
 
             for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
+                if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
+                    raw_out_logits[seq_idx][i] = Some(logits.i(logit_idx)?);
+                }
                 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
             }
         }
@@ -554,17 +575,10 @@
                 ForwardInputsResult::RawLogits { .. } => {
                     response::send_raw_responses(
                         input_seqs,
-                        logits
-                            .iter()
-                            .map(|r| {
-                                #[allow(irrefutable_let_patterns)]
-                                let ForwardInputsResult::RawLogits { logits } = r
-                                else {
-                                    unreachable!("All results must have same type")
-                                };
-                                logits.clone()
-                            })
-                            .collect::<Vec<_>>(),
+                        raw_out_logits
+                            .into_iter()
+                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
+                            .collect(),
                     )
                     .await?;
                 }
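For context on the mod.rs hunks above: each sequence's prompt is now run in ceil(prompt_len / prompt_batchsize) forward passes, and raw_out_logits is sized to the longest sequence so that pass i can slot its chunk of logits at index i; the later .flatten() drops any slots that a shorter sequence never filled. A minimal standalone sketch of that sizing and bookkeeping (names here are illustrative, not from the patch):

    /// Number of prompt chunks for one sequence: integer ceiling division,
    /// matching the `len_inputs` computation in the patch.
    fn num_prompt_chunks(prompt_len: usize, prompt_batchsize: usize) -> usize {
        (prompt_len + prompt_batchsize - 1) / prompt_batchsize
    }

    fn main() {
        // A 5000-token prompt with prompt_batchsize = 2048 runs in 3 chunks:
        // [0..2048), [2048..4096), [4096..5000).
        assert_eq!(num_prompt_chunks(5000, 2048), 3);

        // Per-sequence storage mirrors raw_out_logits: one Option slot per chunk,
        // flattened at the end so unfilled slots simply disappear.
        let mut slots: Vec<Option<u32>> = vec![None; num_prompt_chunks(5000, 2048)];
        slots[0] = Some(10);
        slots[2] = Some(30); // slot 1 intentionally left unfilled
        let filled: Vec<u32> = slots.into_iter().flatten().collect();
        assert_eq!(filled, vec![10, 30]);
    }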
diff --git a/mistralrs-core/src/pipeline/response.rs b/mistralrs-core/src/pipeline/response.rs
index 7442a4860..ee0e5e993 100644
--- a/mistralrs-core/src/pipeline/response.rs
+++ b/mistralrs-core/src/pipeline/response.rs
@@ -72,10 +72,10 @@ pub async fn send_image_responses(
 
 pub async fn send_raw_responses(
     input_seqs: &mut [&mut Sequence],
-    logits: Vec<Tensor>,
+    logits_chunks: Vec<Vec<Tensor>>,
 ) -> candle_core::Result<()> {
-    let logits = if logits.len() == 1 {
-        logits[0].clone()
+    let logits_chunks = if logits_chunks.len() == 1 {
+        logits_chunks[0].clone()
     } else {
         candle_core::bail!("Raw response only supports batch size of 1.");
     };
@@ -83,7 +83,7 @@
 
     let seq = &mut *input_seqs[0];
 
-    seq.add_raw_choice_to_group(logits);
+    seq.add_raw_choice_to_group(logits_chunks);
 
     let group = seq.get_mut_group();
     group
diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs
index d51b5b009..d3c9f22fa 100644
--- a/mistralrs-core/src/response.rs
+++ b/mistralrs-core/src/response.rs
@@ -240,7 +240,10 @@ pub enum Response {
     // Image generation
     ImageGeneration(ImageGenerationResponse),
     // Raw
-    Raw { logits: Tensor, tokens: Vec<u32> },
+    Raw {
+        logits_chunks: Vec<Tensor>,
+        tokens: Vec<u32>,
+    },
 }
 
 #[derive(Debug, Clone)]
@@ -254,7 +257,10 @@ pub enum ResponseOk {
     // Image generation
     ImageGeneration(ImageGenerationResponse),
     // Raw
-    Raw { logits: Tensor, tokens: Vec<u32> },
+    Raw {
+        logits_chunks: Vec<Tensor>,
+        tokens: Vec<u32>,
+    },
 }
 
 pub enum ResponseErr {
@@ -317,7 +323,13 @@ impl Response {
                 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
             }
             Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
-            Self::Raw { logits, tokens } => Ok(ResponseOk::Raw { logits, tokens }),
+            Self::Raw {
+                logits_chunks,
+                tokens,
+            } => Ok(ResponseOk::Raw {
+                logits_chunks,
+                tokens,
+            }),
         }
     }
 }
diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs
index 699bcf3c0..c97d906b4 100644
--- a/mistralrs-core/src/sequence.rs
+++ b/mistralrs-core/src/sequence.rs
@@ -711,10 +711,10 @@ impl Sequence {
         self.update_time_info();
     }
 
-    pub fn add_raw_choice_to_group(&self, logits: Tensor) {
+    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
         get_mut_group!(self)
             .raw_choices
-            .push((logits, self.tokens.clone()));
+            .push((logit_chunks, self.tokens.clone()));
         self.update_time_info();
     }
 
@@ -786,7 +786,7 @@ pub struct SequenceGroup {
     pub total_completion_time: u128,
     choices: Vec<Choice>,
     image_choices: Vec<ImageChoice>,
-    raw_choices: Vec<(Tensor, Vec<u32>)>,
+    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
     completion_choices: Vec<(f32, CompletionChoice)>,
     pub chat_streaming_chunks: Vec<ChunkChoice>,
     pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
@@ -872,8 +872,13 @@ impl SequenceGroup {
     ) -> Result<(), SendError<Response>> {
         if self.raw_choices.len() == self.n_choices {
             assert_eq!(self.raw_choices.len(), 1);
-            let (logits, tokens) = self.raw_choices[0].clone();
-            sender.send(Response::Raw { logits, tokens }).await?;
+            let (logits_chunks, tokens) = self.raw_choices[0].clone();
+            sender
+                .send(Response::Raw {
+                    logits_chunks,
+                    tokens,
+                })
+                .await?;
         }
 
         Ok(())
     }
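For downstream code, the visible effect of the response.rs and sequence.rs hunks above is that Response::Raw now carries one logits tensor per prompt chunk instead of a single tensor. A hedged sketch of a consumer (the variant shape is from the patch; the function itself is illustrative and assumes Response is re-exported at the mistralrs_core crate root):

    use candle_core::Tensor;
    use mistralrs_core::Response;

    /// Illustrative only, not part of the patch: count the total number of
    /// logit rows across all chunks of a raw response.
    fn total_logit_rows(resp: Response) -> usize {
        match resp {
            Response::Raw { logits_chunks, tokens: _ } => logits_chunks
                .iter()
                // Dimension 0 of each chunk is its number of prompt positions.
                .map(|t: &Tensor| t.dims().first().copied().unwrap_or(0))
                .sum(),
            _ => 0,
        }
    }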
diff --git a/mistralrs/examples/perplexity/main.rs b/mistralrs/examples/perplexity/main.rs
index ff5cfc779..c95c989f4 100644
--- a/mistralrs/examples/perplexity/main.rs
+++ b/mistralrs/examples/perplexity/main.rs
@@ -1,4 +1,4 @@
-use std::fs::read_to_string;
+use std::{fs::read_to_string, num::NonZeroUsize};
 
 use anyhow::Result;
 use clap::Parser;
@@ -34,8 +34,10 @@ async fn main() -> Result<()> {
         None
     };
 
+    let prompt_batchsize = 2048;
     let mut model_builder = TextModelBuilder::new(&args.model_id)
         .with_logging()
+        .with_prompt_batchsize(NonZeroUsize::new(prompt_batchsize).unwrap())
        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
     if let Some(quant) = quant {
         model_builder = model_builder.with_isq(quant);
@@ -47,18 +49,26 @@
 
     let (logits, tokens) = model.send_raw_chat_request(messages).await?;
 
-    // Upcast to float if we need to compute the loss to avoid potential precision issues
-    let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
-    // Shift so that tokens < n predict n
-    let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
-    let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
+    for (i, (logits, tokens)) in logits
+        .into_iter()
+        .zip(tokens.chunks(prompt_batchsize))
+        .enumerate()
+    {
+        // Upcast to float if we need to compute the loss to avoid potential precision issues
+        let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
+        // Shift so that tokens < n predict n
+        let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
+        let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
 
-    let loss_fct = cross_entropy_loss(&shift_logits, &shift_labels)?;
-    let perplexity = loss_fct.exp()?.to_scalar::<f32>()?;
-    println!(
-        "Perplexity for `{}`, ISQ `{:?}`: {perplexity}",
-        args.file, quant
-    );
+        let loss_fct = cross_entropy_loss(&shift_logits, &shift_labels)?;
+        let perplexity = loss_fct.exp()?.to_scalar::<f32>()?;
+        println!(
+            "Chunk {i} ({} tokens): Perplexity for `{}`, ISQ `{:?}`: {perplexity}",
+            tokens.len(),
+            args.file,
+            quant
+        );
+    }
 
     Ok(())
 }
diff --git a/mistralrs/src/model.rs b/mistralrs/src/model.rs
index 0ace48f43..91c75dafe 100644
--- a/mistralrs/src/model.rs
+++ b/mistralrs/src/model.rs
@@ -95,11 +95,11 @@ impl Model {
 
     /// Generate with the model, returning raw logits of the first token generated.
     ///
-    /// Returns the logits and the tokens.
+    /// Returns the chunks of the logits (1 or more, determined by prompt batchsize) and the tokens.
     pub async fn send_raw_chat_request<R: RequestLike>(
         &self,
         mut request: R,
-    ) -> anyhow::Result<(Tensor, Vec<u32>)> {
+    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
         let (tx, mut rx) = channel(1);
 
         let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
@@ -125,7 +125,10 @@
 
         self.runner.get_sender()?.send(request).await?;
 
-        let ResponseOk::Raw { logits, tokens } = rx
+        let ResponseOk::Raw {
+            logits_chunks,
+            tokens,
+        } = rx
             .recv()
             .await
             .context("Channel was erroneously closed!")?
@@ -134,7 +137,7 @@
             anyhow::bail!("Got unexpected response type.")
         };
 
-        Ok((logits, tokens))
+        Ok((logits_chunks, tokens))
     }
 
     pub async fn generate_image(
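End to end, send_raw_chat_request now returns a Vec<Tensor> with one logits tensor per prompt chunk, which callers zip against tokens.chunks(prompt_batchsize) exactly as the updated example does. A condensed, hedged sketch of that calling pattern (the helper below is illustrative, not part of the patch; it assumes a built Model, the same prompt_batchsize passed to the builder, at least two tokens per chunk, and uses candle_nn's cross entropy directly in place of the example's local cross_entropy_loss wrapper):

    use anyhow::Result;
    use candle_core::{DType, Device, Tensor};
    use mistralrs::{Model, TextMessages};

    /// Illustrative helper: compute a per-chunk perplexity from the raw logits.
    async fn chunked_perplexity(
        model: &Model,
        messages: TextMessages,
        prompt_batchsize: usize,
    ) -> Result<Vec<f32>> {
        let (logits_chunks, tokens) = model.send_raw_chat_request(messages).await?;
        let mut perplexities = Vec::new();
        for (logits, tokens) in logits_chunks.into_iter().zip(tokens.chunks(prompt_batchsize)) {
            // Upcast on CPU, then shift so that tokens < n predict n.
            let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
            let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
            let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
            let loss = candle_nn::loss::cross_entropy(&shift_logits, &shift_labels)?;
            perplexities.push(loss.exp()?.to_scalar::<f32>()?);
        }
        Ok(perplexities)
    }

Note that reporting per-chunk perplexity, rather than one number for the whole file, is a direct consequence of chunked prefill: logits for chunk i are produced while the KV cache already holds chunks 0..i, so each chunk's loss is conditioned on the full preceding context but is materialized as a separate tensor.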