Add prompt batching support

EricLBuehler committed Nov 25, 2024
1 parent b6467c2 commit 132cb41
Showing 6 changed files with 96 additions and 52 deletions.
62 changes: 38 additions & 24 deletions mistralrs-core/src/pipeline/mod.rs
@@ -327,9 +327,20 @@ pub trait Pipeline:
         );
 
         let mut logits = vec![None; input_seqs.len()];
+        let prompt_batchsize = self
+            .get_metadata()
+            .prompt_batchsize
+            .map(NonZeroUsize::get)
+            .unwrap_or(1);
+        let len_inputs = input_seqs
+            .iter()
+            .map(|seq| (seq.get_toks().len() + prompt_batchsize - 1) / prompt_batchsize)
+            .max()
+            .unwrap();
+        let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
 
         let mut exec_duration = Duration::ZERO;
-        for (i, inputs) in inputs_iter.enumerate() {
+        for (i, inputs) in inputs_iter.into_iter().enumerate() {
             let InputProcessorOutput {
                 inputs,
                 seq_indices,
Expand Down Expand Up @@ -399,6 +410,9 @@ pub trait Pipeline:
exec_duration += end.duration_since(start);

for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
raw_out_logits[seq_idx][i] = Some(logits.i(logit_idx)?);
}
logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
}
}
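Note: the chunk count above is a ceiling division, so a prompt of n tokens yields ceil(n / prompt_batchsize) slots per sequence in raw_out_logits. A minimal sketch of the same arithmetic (the num_chunks helper is illustrative, not part of this commit):

/// Ceiling division: the number of prompt chunks needed for `n_toks` tokens.
fn num_chunks(n_toks: usize, prompt_batchsize: usize) -> usize {
    (n_toks + prompt_batchsize - 1) / prompt_batchsize
}

fn main() {
    // A 5000-token prompt with prompt_batchsize = 2048 -> 3 chunks
    // (2048 + 2048 + 904 tokens).
    assert_eq!(num_chunks(5000, 2048), 3);
}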
Expand Down Expand Up @@ -432,17 +446,10 @@ pub trait Pipeline:
ForwardInputsResult::RawLogits { .. } => {
response::send_raw_responses(
input_seqs,
logits
.iter()
.map(|r| {
#[allow(irrefutable_let_patterns)]
let ForwardInputsResult::RawLogits { logits } = r
else {
unreachable!("All results must have same type")
};
logits.clone()
})
.collect::<Vec<_>>(),
raw_out_logits
.into_iter()
.map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
.collect(),
)
.await?;
}
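Note: raw_out_logits rows are sized for the longest sequence, so shorter sequences leave trailing None slots, which the flatten above drops. A small self-contained sketch of that collection pattern, with integers standing in for tensors:

fn main() {
    // Two sequences: the first produced 3 chunks, the second only 1,
    // so its trailing slots remain None and are dropped by `flatten`.
    let raw_out: Vec<Vec<Option<u32>>> = vec![
        vec![Some(10), Some(11), Some(12)],
        vec![Some(20), None, None],
    ];
    let per_seq: Vec<Vec<u32>> = raw_out
        .into_iter()
        .map(|row| row.into_iter().flatten().collect())
        .collect();
    assert_eq!(per_seq, vec![vec![10, 11, 12], vec![20]]);
}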
Expand Down Expand Up @@ -523,9 +530,20 @@ pub trait Pipeline:
);

let mut logits = vec![None; input_seqs.len()];
let prompt_batchsize = self
.get_metadata()
.prompt_batchsize
.map(NonZeroUsize::get)
.unwrap_or(1);
let len_inputs = input_seqs
.iter()
.map(|seq| (seq.get_toks().len() + prompt_batchsize - 1) / prompt_batchsize)
.max()
.unwrap();
let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

let mut exec_duration = Duration::ZERO;
for inputs in inputs_iter {
for (i, inputs) in inputs_iter.into_iter().enumerate() {
let InputProcessorOutput {
inputs,
seq_indices,
@@ -537,6 +555,9 @@ pub trait Pipeline:
             exec_duration += end.duration_since(start);
 
             for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
+                if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
+                    raw_out_logits[seq_idx][i] = Some(logits.i(logit_idx)?);
+                }
                 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
             }
         }
@@ -554,17 +575,10 @@ pub trait Pipeline:
             ForwardInputsResult::RawLogits { .. } => {
                 response::send_raw_responses(
                     input_seqs,
-                    logits
-                        .iter()
-                        .map(|r| {
-                            #[allow(irrefutable_let_patterns)]
-                            let ForwardInputsResult::RawLogits { logits } = r
-                            else {
-                                unreachable!("All results must have same type")
-                            };
-                            logits.clone()
-                        })
-                        .collect::<Vec<_>>(),
+                    raw_out_logits
+                        .into_iter()
+                        .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
+                        .collect(),
                 )
                 .await?;
             }
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/response.rs
@@ -72,18 +72,18 @@ pub async fn send_image_responses(
 
 pub async fn send_raw_responses(
     input_seqs: &mut [&mut Sequence],
-    logits: Vec<Tensor>,
+    logits_chunks: Vec<Vec<Tensor>>,
 ) -> candle_core::Result<()> {
-    let logits = if logits.len() == 1 {
-        logits[0].clone()
+    let logits_chunks = if logits_chunks.len() == 1 {
+        logits_chunks[0].clone()
     } else {
         candle_core::bail!("Raw response only supports batch size of 1.");
     };
     assert_eq!(input_seqs.len(), 1);
 
     let seq = &mut *input_seqs[0];
 
-    seq.add_raw_choice_to_group(logits);
+    seq.add_raw_choice_to_group(logits_chunks);
 
     let group = seq.get_mut_group();
     group
18 changes: 15 additions & 3 deletions mistralrs-core/src/response.rs
@@ -240,7 +240,10 @@ pub enum Response {
     // Image generation
     ImageGeneration(ImageGenerationResponse),
     // Raw
-    Raw { logits: Tensor, tokens: Vec<u32> },
+    Raw {
+        logits_chunks: Vec<Tensor>,
+        tokens: Vec<u32>,
+    },
 }
 
 #[derive(Debug, Clone)]
@@ -254,7 +257,10 @@ pub enum ResponseOk {
     // Image generation
     ImageGeneration(ImageGenerationResponse),
     // Raw
-    Raw { logits: Tensor, tokens: Vec<u32> },
+    Raw {
+        logits_chunks: Vec<Tensor>,
+        tokens: Vec<u32>,
+    },
 }
 
 pub enum ResponseErr {
@@ -317,7 +323,13 @@ impl Response {
                 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
             }
             Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
-            Self::Raw { logits, tokens } => Ok(ResponseOk::Raw { logits, tokens }),
+            Self::Raw {
+                logits_chunks,
+                tokens,
+            } => Ok(ResponseOk::Raw {
+                logits_chunks,
+                tokens,
+            }),
         }
     }
 }
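Note: consumers of the raw variant now receive one tensor per prompt chunk. A hedged sketch of stitching them back together, assuming each chunk holds per-token logits of shape (chunk_len, vocab_size); the helper name is illustrative:

use candle_core::{Result, Tensor};

/// Concatenate per-chunk logits back into one (total_len, vocab_size) tensor.
fn recover_full_logits(logits_chunks: &[Tensor]) -> Result<Tensor> {
    Tensor::cat(logits_chunks, 0)
}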
15 changes: 10 additions & 5 deletions mistralrs-core/src/sequence.rs
@@ -711,10 +711,10 @@ impl Sequence {
         self.update_time_info();
     }
 
-    pub fn add_raw_choice_to_group(&self, logits: Tensor) {
+    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
         get_mut_group!(self)
             .raw_choices
-            .push((logits, self.tokens.clone()));
+            .push((logit_chunks, self.tokens.clone()));
         self.update_time_info();
     }

@@ -786,7 +786,7 @@ pub struct SequenceGroup {
     pub total_completion_time: u128,
     choices: Vec<Choice>,
     image_choices: Vec<ImageChoice>,
-    raw_choices: Vec<(Tensor, Vec<u32>)>,
+    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
     completion_choices: Vec<(f32, CompletionChoice)>,
     pub chat_streaming_chunks: Vec<ChunkChoice>,
     pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
@@ -872,8 +872,13 @@ impl SequenceGroup {
     ) -> Result<(), SendError<Response>> {
         if self.raw_choices.len() == self.n_choices {
             assert_eq!(self.raw_choices.len(), 1);
-            let (logits, tokens) = self.raw_choices[0].clone();
-            sender.send(Response::Raw { logits, tokens }).await?;
+            let (logits_chunks, tokens) = self.raw_choices[0].clone();
+            sender
+                .send(Response::Raw {
+                    logits_chunks,
+                    tokens,
+                })
+                .await?;
         }
 
         Ok(())
34 changes: 22 additions & 12 deletions mistralrs/examples/perplexity/main.rs
@@ -1,4 +1,4 @@
-use std::fs::read_to_string;
+use std::{fs::read_to_string, num::NonZeroUsize};
 
 use anyhow::Result;
 use clap::Parser;
@@ -34,8 +34,10 @@ async fn main() -> Result<()> {
         None
     };
 
+    let prompt_batchsize = 2048;
     let mut model_builder = TextModelBuilder::new(&args.model_id)
         .with_logging()
+        .with_prompt_batchsize(NonZeroUsize::new(prompt_batchsize).unwrap())
        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
     if let Some(quant) = quant {
         model_builder = model_builder.with_isq(quant);
@@ -47,18 +49,26 @@ async fn main() -> Result<()> {
 
     let (logits, tokens) = model.send_raw_chat_request(messages).await?;
 
-    // Upcast to float if we need to compute the loss to avoid potential precision issues
-    let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
-    // Shift so that tokens < n predict n
-    let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
-    let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
+    for (i, (logits, tokens)) in logits
+        .into_iter()
+        .zip(tokens.chunks(prompt_batchsize))
+        .enumerate()
+    {
+        // Upcast to float if we need to compute the loss to avoid potential precision issues
+        let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
+        // Shift so that tokens < n predict n
+        let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
+        let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
 
-    let loss_fct = cross_entropy_loss(&shift_logits, &shift_labels)?;
-    let perplexity = loss_fct.exp()?.to_scalar::<f32>()?;
-    println!(
-        "Perplexity for `{}`, ISQ `{:?}`: {perplexity}",
-        args.file, quant
-    );
+        let loss_fct = cross_entropy_loss(&shift_logits, &shift_labels)?;
+        let perplexity = loss_fct.exp()?.to_scalar::<f32>()?;
+        println!(
+            "Chunk {i} ({} tokens): Perplexity for `{}`, ISQ `{:?}`: {perplexity}",
+            tokens.len(),
+            args.file,
+            quant
+        );
+    }
 
     Ok(())
 }
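Note: cross_entropy_loss returns the mean per-token loss, so each printed value is exp(mean CE) for that chunk alone. If one whole-file figure is wanted, the chunk losses can be token-weighted before exponentiating; a sketch under that assumption (names illustrative, not part of this commit):

/// Combine (mean cross-entropy, token count) pairs from each chunk into
/// one token-weighted perplexity for the whole file.
fn overall_perplexity(chunks: &[(f32, usize)]) -> f32 {
    let total_tokens: usize = chunks.iter().map(|(_, n)| n).sum();
    let weighted: f32 = chunks.iter().map(|(ce, n)| ce * *n as f32).sum();
    (weighted / total_tokens as f32).exp()
}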
11 changes: 7 additions & 4 deletions mistralrs/src/model.rs
@@ -95,11 +95,11 @@ impl Model {
 
     /// Generate with the model, returning raw logits of the first token generated.
     ///
-    /// Returns the logits and the tokens.
+    /// Returns the chunks of the logits (1 or more, determined by prompt batchsize) and the tokens.
     pub async fn send_raw_chat_request<R: RequestLike>(
         &self,
         mut request: R,
-    ) -> anyhow::Result<(Tensor, Vec<u32>)> {
+    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
         let (tx, mut rx) = channel(1);
 
         let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
@@ -125,7 +125,10 @@ impl Model {
 
         self.runner.get_sender()?.send(request).await?;
 
-        let ResponseOk::Raw { logits, tokens } = rx
+        let ResponseOk::Raw {
+            logits_chunks,
+            tokens,
+        } = rx
             .recv()
             .await
             .context("Channel was erroneously closed!")?
@@ -134,7 +137,7 @@ impl Model {
             anyhow::bail!("Got unexpected response type.")
         };
 
-        Ok((logits, tokens))
+        Ok((logits_chunks, tokens))
     }
 
     pub async fn generate_image(
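Note: a hedged usage fragment for the updated API, assuming model and messages are built as in the perplexity example above:

// One logits tensor per prompt chunk, plus the tokenized prompt.
let (logits_chunks, tokens) = model.send_raw_chat_request(messages).await?;
println!(
    "received {} logits chunk(s) for {} tokens",
    logits_chunks.len(),
    tokens.len()
);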