Attention-fused softmax for Metal #908

Merged
merged 5 commits on Nov 14, 2024
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
41 changes: 24 additions & 17 deletions mistralrs-core/src/attention.rs
@@ -92,24 +92,31 @@ fn naive_sdpa(
head_dim: usize,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
let mut att = MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(head_dim as f64).sqrt(),
)?;
if let Some(softcap) = sdpa_params.softcap {
att = (att / softcap as f64)?;
att = att.tanh()?;
att = (att * softcap as f64)?;
}
if let Some(mask) = mask {
let mut att = MatMul.matmul(q, &k.t()?)?;
if let Some(softcap) = sdpa_params.softcap {
att = (att / softcap as f64)?;
att = att.tanh()?;
att = (att * softcap as f64)?;
}

let att = candle_nn::ops::attn_softmax_last_dim(&att, mask, 1. / (head_dim as f32).sqrt())?;
MatMul.matmul(&att, v)
} else {
let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;
if let Some(softcap) = sdpa_params.softcap {
att = (att / softcap as f64)?;
att = att.tanh()?;
att = (att * softcap as f64)?;
}

let att = match mask {
Some(m) => att.broadcast_add(m)?,
None => att,
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
MatMul.matmul(&att, &v.contiguous()?)
let att = match mask {
Some(m) => att.broadcast_add(m)?,
None => att,
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
MatMul.matmul(&att, v)
}
}

pub struct SdpaParams {
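
A note on the new masked branch above: `candle_nn::ops::attn_softmax_last_dim(&att, mask, scale)` is assumed to fuse the logit scaling, the additive mask, and the softmax into a single Metal kernel. Below is a minimal unfused reference built from the same candle ops already used in the `else` branch; the helper name `unfused_attn_softmax` is ours, for illustration only.

use candle_core::{Result, Tensor};

/// Unfused reference for the masked branch: scale the raw QK^T logits,
/// add the additive mask, then softmax over the last dimension. The fused
/// `attn_softmax_last_dim(att, mask, scale)` call is expected to be
/// numerically equivalent to this.
fn unfused_attn_softmax(att: &Tensor, mask: &Tensor, scale: f32) -> Result<Tensor> {
    let att = (att * scale as f64)?;
    let att = att.broadcast_add(mask)?;
    candle_nn::ops::softmax_last_dim(&att)
}
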
124 changes: 81 additions & 43 deletions mistralrs-core/src/layers_masker.rs
@@ -122,31 +122,28 @@ impl CausalMasker {
return Ok(k_cache_1.dims()[2]);
}

pub fn make_causal_mask_as_attn_bias(
pub fn make_causal_mask_matrix(
&self,
input_ids: &Tensor,
cache: &dyn PastKvLenCache,
dtype: DType,
n_attn_heads: usize,
) -> Result<Option<Tensor>> {
let past_kv_len = cache.get_past_kv_len()?;
let (b_sz, tgt_len) = input_ids.dims2()?;
let (_b_sz, tgt_len) = input_ids.dims2()?;
if tgt_len == 1 {
return Ok(None);
}

let causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
let mask = self
.make_mask(tgt_len, past_kv_len, input_ids.device())?
.to_dtype(DType::U8)?;
Some(mask)
};

let zero = Tensor::new(0.0f32, input_ids.device())?;
let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
let mask =
mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
let mask = masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
@@ -164,20 +161,19 @@ impl CausalMasker {
Ok(mask)
}

pub fn make_causal_mask_with_sliding_window_as_attn_bias(
pub fn make_sliding_window_causal_mask_matrix(
&self,
input_ids: &Tensor,
cache: &dyn PastKvLenCache,
sliding_window: Option<usize>,
dtype: DType,
n_attn_heads: usize,
) -> Result<Option<Tensor>> {
if sliding_window.is_none() {
return self.make_causal_mask_as_attn_bias(input_ids, cache, dtype, n_attn_heads);
return self.make_causal_mask_matrix(input_ids, cache, dtype);
}
let sliding_window = sliding_window.unwrap();
let past_kv_len = cache.get_past_kv_len()?;
let (b_sz, tgt_len) = input_ids.dims2()?;
let (_b_sz, tgt_len) = input_ids.dims2()?;
if tgt_len == 1 {
return Ok(None);
}
@@ -186,18 +182,15 @@
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
.to_dtype(DType::U8)?;

Some(mask)
};

let zero = Tensor::new(0.0f32, input_ids.device())?;
let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
let mask =
mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
let mask = masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
@@ -216,60 +209,105 @@
}

#[deprecated(
since = "0.1.10",
note = "use `make_causal_mask_as_attn_bias` instead! \
This is *not* compatible with `Sdpa`"
since = "0.3.3",
note = "use `make_causal_mask_matrix_as_attn_bias` instead. This is incompatible with `Sdpa`."
)]
pub fn make_causal_mask(
pub fn make_causal_mask_as_attn_bias(
&self,
input_ids: &Tensor,
cache: &[Option<(Tensor, Tensor)>],
cache: &dyn PastKvLenCache,
dtype: DType,
n_attn_heads: usize,
) -> Result<Option<Tensor>> {
let past_kv_len = self.calculate_past_kv_len(cache)?;
let past_kv_len = cache.get_past_kv_len()?;
let (b_sz, tgt_len) = input_ids.dims2()?;
if tgt_len == 1 {
return Ok(None);
}

let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;
let causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;
Some(mask)
};

Ok(Some(mask))
let zero = Tensor::new(0.0f32, input_ids.device())?;
let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
let mask =
mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
let mask = masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
&mask,
f32::NEG_INFINITY,
)?;

Ok(mask)
});
let mask: Option<Tensor> = if let Some(mask) = causal_mask {
Some(mask?)
} else {
None
};
Ok(mask)
}

#[deprecated(
since = "0.1.10",
note = "use `make_causal_mask_with_sliding_window_as_attn_bias` instead! \
This is *not* compatible with `Sdpa`"
since = "0.3.3",
note = "use `make_causal_mask_matrix_with_sliding_window_as_attn_bias` instead. This is incompatible with `Sdpa`."
)]
pub fn make_causal_mask_with_sliding_window(
pub fn make_causal_mask_with_sliding_window_as_attn_bias(
&self,
input_ids: &Tensor,
cache: &[Option<(Tensor, Tensor)>],
cache: &dyn PastKvLenCache,
sliding_window: Option<usize>,
dtype: DType,
n_attn_heads: usize,
) -> Result<Option<Tensor>> {
if sliding_window.is_none() {
#[allow(deprecated)]
return self.make_causal_mask(input_ids, cache);
return self.make_causal_mask_as_attn_bias(input_ids, cache, dtype, n_attn_heads);
}
let sliding_window = sliding_window.unwrap();
let past_kv_len = self.calculate_past_kv_len(cache)?;
let past_kv_len = cache.get_past_kv_len()?;
let (b_sz, tgt_len) = input_ids.dims2()?;
if tgt_len == 1 {
return Ok(None);
}

let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;
let causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;

Ok(Some(mask))
Some(mask)
};

let zero = Tensor::new(0.0f32, input_ids.device())?;
let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
let mask =
mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
let mask = masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
&mask,
f32::NEG_INFINITY,
)?;

Ok(mask)
});
let mask: Option<Tensor> = if let Some(mask) = causal_mask {
Some(mask?)
} else {
None
};
Ok(mask)
}

pub fn apply_mask_one_and_zero(
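
As a reading aid for the renamed helpers above: `make_causal_mask_matrix` and `make_sliding_window_causal_mask_matrix` now return a plain u8 mask matrix of shape (tgt_len, tgt_len + past_kv_len) instead of a pre-broadcast attention bias. A rough sketch of that matrix on plain nested Vecs (ours, not the candle implementation), assuming a value of 1 marks a position the caller later fills with -inf:

/// Illustrative construction of the u8 causal mask matrix, with an optional
/// sliding window. Row i is the i-th query of the new chunk; column j indexes
/// the concatenated past + new key positions.
fn causal_mask_matrix(
    tgt_len: usize,
    past_kv_len: usize,
    sliding_window: Option<usize>,
) -> Vec<Vec<u8>> {
    (0..tgt_len)
        .map(|i| {
            (0..tgt_len + past_kv_len)
                .map(|j| {
                    // Causal part: query i may not attend to keys past i + past_kv_len.
                    let future = j > i + past_kv_len;
                    // Sliding-window part, mirroring the tril with
                    // diagonal = past_kv_len - sliding_window - 1 above:
                    // keys too far in the past are masked as well.
                    let too_old = sliding_window
                        .map(|w| j + w + 1 <= i + past_kv_len)
                        .unwrap_or(false);
                    u8::from(future || too_old)
                })
                .collect()
        })
        .collect()
}
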
3 changes: 1 addition & 2 deletions mistralrs-core/src/models/gemma.rs
@@ -555,14 +555,13 @@ impl Model {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let mut cache = self.cache.lock();
let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
let attention_mask = CausalMasker.make_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
xs.dtype(),
self.layers[0].self_attn.num_heads,
)?;
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
14 changes: 4 additions & 10 deletions mistralrs-core/src/models/gemma2.rs
@@ -609,20 +609,14 @@ impl Model {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let mut cache = self.cache.lock();
let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
let attention_mask =
CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
input_ids,
&*cache,
Some(self.sliding_window),
xs.dtype(),
self.layers[0].self_attn.num_heads,
)?;
let sliding_attention_mask = CausalMasker
.make_causal_mask_with_sliding_window_as_attn_bias(
input_ids,
&*cache,
Some(self.sliding_window),
xs.dtype(),
self.layers[0].self_attn.num_heads,
)?;
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
xs = layer.forward(
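
Since gemma2 now builds both a full causal mask and a sliding-window mask up front, each decoder layer presumably selects one of the two depending on whether it uses local attention. A hypothetical helper sketching that selection; the `uses_sliding_window` flag is illustrative and not a field shown in this diff.

use candle_core::Tensor;

/// Hypothetical per-layer mask selection: sliding-window (local) attention
/// layers take the windowed mask, the others take the full causal mask.
fn mask_for_layer<'a>(
    uses_sliding_window: bool,
    attention_mask: &'a Option<Tensor>,
    sliding_attention_mask: &'a Option<Tensor>,
) -> Option<&'a Tensor> {
    if uses_sliding_window {
        sliding_attention_mask.as_ref()
    } else {
        attention_mask.as_ref()
    }
}
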
3 changes: 1 addition & 2 deletions mistralrs-core/src/models/llama.rs
@@ -411,14 +411,13 @@ impl Llama {
) -> Result<Tensor> {
let mut x = self.wte.forward(input_ids)?;
let mut cache = self.kv_cache.lock();
let mask = CausalMasker.make_causal_mask_as_attn_bias(
let mask = CausalMasker.make_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
x.dtype(),
self.blocks[0].attn.num_attention_heads,
)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = self.mapper.map(x, block_idx)?;