
Commit

Fix some accuracy issues
EricLBuehler committed Nov 3, 2024
1 parent f9d6682 commit 679c788
Showing 6 changed files with 94 additions and 141 deletions.
29 changes: 29 additions & 0 deletions mistralrs-core/src/layers.rs
@@ -73,6 +73,35 @@ impl Module for RmsNorm {
    }
}

#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}
}
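
The forward pass of F32RmsNorm upcasts to f32 for the mean/rsqrt reduction and only casts back to the input dtype before applying the weight. A minimal usage sketch follows (not part of the diff); it assumes F32RmsNorm is importable from this layers module, and the zeros-backed VarBuilder, the 1e-5 epsilon, and the tensor shapes are illustrative stand-ins.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};

fn f32_rmsnorm_sketch() -> Result<()> {
    let dev = Device::Cpu;
    // A zeros-backed VarBuilder stands in for real checkpoint weights.
    let vb = VarBuilder::zeros(DType::F16, &dev);
    let norm = F32RmsNorm::new(4096, 1e-5, vb)?;

    // Half-precision activations: forward() reduces in f32, then casts back
    // to f16 before multiplying by the learned scale.
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 4096), &dev)?.to_dtype(DType::F16)?;
    let ys = norm.forward(&xs)?;
    assert_eq!(ys.dtype(), DType::F16);
    Ok(())
}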

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
8 changes: 7 additions & 1 deletion mistralrs-core/src/utils/unvarbuilder.rs
@@ -8,7 +8,7 @@ use candle_nn::{Conv2d, Embedding, LayerNorm, Linear};
use itertools::Itertools;
use mistralrs_quant::QuantMethod;

use crate::layers::{FusedBiasLinear, QLinear, RmsNorm};
use crate::layers::{F32RmsNorm, FusedBiasLinear, QLinear, RmsNorm};

pub trait ToTensors {
    /// Tensor names to tensors
@@ -27,6 +27,12 @@ impl ToTensors for RmsNorm {
    }
}

impl ToTensors for F32RmsNorm {
    fn to_tensors(&self) -> HashMap<String, Tensor> {
        HashMap::from_iter([("weight".to_string(), self.weight().clone())])
    }
}
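
For reference, a small sketch (again not from the commit) of what this impl yields: a map with a single "weight" entry, which UnVarBuilder can prefix and collect when serializing the model. It assumes F32RmsNorm and the ToTensors trait are in scope; the size and epsilon are arbitrary.

use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn to_tensors_sketch() -> Result<()> {
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let norm = F32RmsNorm::new(16, 1e-6, vb)?;
    // Exactly one entry, keyed "weight", holding the norm's scale vector.
    let tensors = norm.to_tensors();
    assert_eq!(tensors.len(), 1);
    assert!(tensors.contains_key("weight"));
    Ok(())
}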

impl ToTensors for LayerNorm {
    fn to_tensors(&self) -> HashMap<String, Tensor> {
        HashMap::from_iter([
64 changes: 17 additions & 47 deletions mistralrs-core/src/vision_models/mllama/text.rs
@@ -2,52 +2,22 @@

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{embedding, Activation, Embedding, Module, VarBuilder};
use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLinear};

use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
layers::{CausalMasker, Llama3RotaryEmbedding, Sdpa},
layers::{CausalMasker, F32RmsNorm, Llama3RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata},
utils::unvarbuilder::{ToTensors, UnVarBuilder},
utils::unvarbuilder::UnVarBuilder,
};

use super::config::MLlamaTextConfig;

struct MLlamaRmsNorm {
    w: Tensor,
    eps: f64,
}

impl MLlamaRmsNorm {
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }
}

impl Module for MLlamaRmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

impl ToTensors for MLlamaRmsNorm {
    fn to_tensors(&self) -> HashMap<String, Tensor> {
        HashMap::from_iter([("weight".to_string(), self.w.clone())])
    }
}

struct MLlamaTextMlp {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
@@ -234,8 +204,8 @@ impl MLlamaTextSelfAttention {
struct MLlamaSelfAttentionDecoderLayer {
attn: MLlamaTextSelfAttention,
mlp: MLlamaTextMlp,
input_layernorm: MLlamaRmsNorm,
post_attention_layernorm: MLlamaRmsNorm,
input_layernorm: F32RmsNorm,
post_attention_layernorm: F32RmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
@@ -248,12 +218,12 @@ impl MLlamaSelfAttentionDecoderLayer {
loading_isq: bool,
) -> Result<Self> {
let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = MLlamaRmsNorm::new(
let input_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = MLlamaRmsNorm::new(
let post_attention_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
@@ -306,8 +276,8 @@ struct MLlamaTextCrossAttention {
k_proj: Arc<dyn QuantMethod>,
v_proj: Arc<dyn QuantMethod>,
o_proj: Arc<dyn QuantMethod>,
q_norm: MLlamaRmsNorm,
k_norm: MLlamaRmsNorm,
q_norm: F32RmsNorm,
k_norm: F32RmsNorm,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
@@ -346,12 +316,12 @@ impl MLlamaTextCrossAttention {
&cfg.quantization_config,
vb.pp("o_proj"),
)?,
q_norm: MLlamaRmsNorm::new(
q_norm: F32RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("q_norm"), false),
)?,
k_norm: MLlamaRmsNorm::new(
k_norm: F32RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("k_norm"), false),
@@ -460,8 +430,8 @@ struct MLlamaCrossAttentionDecoderLayer {
attn_gate: Tensor,
mlp: MLlamaTextMlp,
mlp_gate: Tensor,
input_layernorm: MLlamaRmsNorm,
post_attention_layernorm: MLlamaRmsNorm,
input_layernorm: F32RmsNorm,
post_attention_layernorm: F32RmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
@@ -473,12 +443,12 @@ impl MLlamaCrossAttentionDecoderLayer {
loading_isq: bool,
) -> Result<Self> {
let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = MLlamaRmsNorm::new(
let input_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = MLlamaRmsNorm::new(
let post_attention_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
@@ -543,7 +513,7 @@ enum MLlamaDecoderLayer {
pub(super) struct MLlamaTextModel {
embed_tokens: Embedding,
lm_head: Arc<dyn QuantMethod>,
norm: MLlamaRmsNorm,
norm: F32RmsNorm,
layers: Vec<MLlamaDecoderLayer>,
pub(crate) cfg: ModelConfigMetadata,
pub(crate) cache: Cache,
@@ -589,7 +559,7 @@ impl MLlamaTextModel {

let vb = vb.pp("model");

let norm = MLlamaRmsNorm::new(
let norm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb.pp("norm"), false),
25 changes: 8 additions & 17 deletions mistralrs-core/src/vision_models/qwen2vl/mod.rs
@@ -296,22 +296,15 @@ impl Qwen2VLModel {
continuous_vid_pad: Vec<Vec<(usize, usize)>>,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let attention_mask = {
let cache = self.text.cache.lock();
CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
self.text.cfg.sliding_window,
self.text.norm.weight().dtype(),
self.text.cfg.num_attn_heads,
)?
};
let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
input_ids,
&seqlen_offsets as &dyn PastKvLenCache,
self.text.cfg.sliding_window,
self.text.dtype,
self.text.cfg.num_attn_heads,
)?;

let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
let mut xs = self.text.embed_tokens(input_ids)?;
@@ -419,7 +412,6 @@ impl Qwen2VLModel {
attention_mask.as_ref(),
&position_ids,
context_lens,
metadata,
flash_params,
)?;
Ok(out)
@@ -445,7 +437,7 @@ impl VisionModel for Qwen2VLModel {
context_lens: Vec<(usize, usize)>,
_position_ids: Vec<usize>,
model_specific_args: Box<dyn Any>,
metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
_metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let Qwen2VLVisionSpecificArgs {
Expand Down Expand Up @@ -478,7 +470,6 @@ impl VisionModel for Qwen2VLModel {
continuous_vid_pad,
seqlen_offsets,
context_lens,
metadata,
flash_params,
)
}
