Skip to content

Commit

Permalink
Add some Metal kernels for HQQ dequant (#907)
Browse files Browse the repository at this point in the history
* Add some metal kernels for HQQ dequant

* Clippy and format

* Add metal bitwise ops

* Oops

* Fix test

* Update gitignore
  • Loading branch information
EricLBuehler authored Nov 11, 2024
1 parent 3fdf496 commit 10dc437
Show file tree
Hide file tree
Showing 16 changed files with 1,216 additions and 41 deletions.
Binary file removed .DS_Store
Binary file not shown.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/target
.ruff_cache
.vscode
*.a
*.a
.DS_Store
12 changes: 7 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "2e17ebd" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "2e17ebd" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "2e17ebd", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
Expand Down
4 changes: 3 additions & 1 deletion mistralrs-quant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ rayon.workspace = true
byteorder = "1.5.0"
float8.workspace = true
once_cell.workspace = true
metal = { version = "0.27.0", features = ["mps"], optional = true }
thiserror = "1"

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"]
metal = ["candle-core/metal", "candle-nn/metal"]
metal = ["candle-core/metal", "candle-nn/metal", "dep:metal"]

[build-dependencies]
bindgen_cuda = { version = "0.1.5", optional = true }
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "metal")]
use candle_core::{backend::BackendStorage, DType};
use candle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType};

/*
Expand All @@ -9,8 +11,8 @@ pub(crate) struct Dequant8Bit {
}

impl Dequant8Bit {
fn dequantize<T: WithDType>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let mut out = Vec::with_capacity(w.len());
fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let mut out = vec![T::default(); w.len()];
for (i, w) in w.iter().enumerate() {
let j = i % self.w;
out[i] = (T::from_f64(*w as f64) - z[j]) * s[j];
Expand Down Expand Up @@ -54,6 +56,54 @@ impl CustomOp3 for Dequant8Bit {
(_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
}
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
w: &candle_core::MetalStorage,
l_w: &Layout,
s: &candle_core::MetalStorage,
l_s: &Layout,
z: &candle_core::MetalStorage,
l_z: &Layout,
) -> Result<(candle_core::MetalStorage, Shape)> {
if w.dtype() != DType::U8 {
candle_core::bail!("Weight must be u8, HQQ dequant 8-bit");
};
if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
candle_core::bail!("All inputs must be contiguous");
}

let command_buffer = w.device().command_buffer()?;
command_buffer.set_label("dequant-8bit");

let device = w.device();

let out_shape = Shape::from_dims(&[self.h, self.w]);

let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-8bit")?;

crate::metal_kernels::call_dequant_8bit(
device.device(),
&command_buffer,
&crate::metal_kernels::Kernels::new(),
s.dtype(),
w.buffer(),
s.buffer(),
z.buffer(),
self.h as u32,
self.w as u32,
&output,
)
.map_err(candle_core::Error::wrap)?;

let newstorage = candle_core::MetalStorage::new(
output,
device.clone(),
out_shape.elem_count(),
s.dtype(),
);
Ok((newstorage, out_shape))
}
}

/*
Expand Down Expand Up @@ -115,6 +165,56 @@ impl CustomOp3 for Dequant4Bit {
(_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
}
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
w: &candle_core::MetalStorage,
l_w: &Layout,
s: &candle_core::MetalStorage,
l_s: &Layout,
z: &candle_core::MetalStorage,
l_z: &Layout,
) -> Result<(candle_core::MetalStorage, Shape)> {
const PACK_FACTOR: usize = 2;

if w.dtype() != DType::U8 {
candle_core::bail!("Weight must be u8, HQQ dequant 4-bit");
};
if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
candle_core::bail!("All inputs must be contiguous");
}

let command_buffer = w.device().command_buffer()?;
command_buffer.set_label("dequant-4bit");

let device = w.device();

let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);

let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-4bit")?;

crate::metal_kernels::call_dequant_4bit(
device.device(),
&command_buffer,
&crate::metal_kernels::Kernels::new(),
s.dtype(),
w.buffer(),
s.buffer(),
z.buffer(),
self.h as u32,
self.w as u32,
&output,
)
.map_err(candle_core::Error::wrap)?;

let newstorage = candle_core::MetalStorage::new(
output,
device.clone(),
out_shape.elem_count(),
s.dtype(),
);
Ok((newstorage, out_shape))
}
}

/*
Expand Down Expand Up @@ -178,6 +278,56 @@ impl CustomOp3 for Dequant2Bit {
(_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
}
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
w: &candle_core::MetalStorage,
l_w: &Layout,
s: &candle_core::MetalStorage,
l_s: &Layout,
z: &candle_core::MetalStorage,
l_z: &Layout,
) -> Result<(candle_core::MetalStorage, Shape)> {
const PACK_FACTOR: usize = 4;

if w.dtype() != DType::U8 {
candle_core::bail!("Weight must be u8, HQQ dequant 2-bit");
};
if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
candle_core::bail!("All inputs must be contiguous");
}

let command_buffer = w.device().command_buffer()?;
command_buffer.set_label("dequant-2bit");

let device = w.device();

let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);

let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-2bit")?;

crate::metal_kernels::call_dequant_2bit(
device.device(),
&command_buffer,
&crate::metal_kernels::Kernels::new(),
s.dtype(),
w.buffer(),
s.buffer(),
z.buffer(),
self.h as u32,
self.w as u32,
&output,
)
.map_err(candle_core::Error::wrap)?;

let newstorage = candle_core::MetalStorage::new(
output,
device.clone(),
out_shape.elem_count(),
s.dtype(),
);
Ok((newstorage, out_shape))
}
}

/*
Expand Down Expand Up @@ -245,6 +395,56 @@ impl CustomOp3 for Dequant1Bit {
(_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
}
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
w: &candle_core::MetalStorage,
l_w: &Layout,
s: &candle_core::MetalStorage,
l_s: &Layout,
z: &candle_core::MetalStorage,
l_z: &Layout,
) -> Result<(candle_core::MetalStorage, Shape)> {
const PACK_FACTOR: usize = 8;

if w.dtype() != DType::U8 {
candle_core::bail!("Weight must be u8, HQQ dequant 1-bit");
};
if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
candle_core::bail!("All inputs must be contiguous");
}

let command_buffer = w.device().command_buffer()?;
command_buffer.set_label("dequant-1bit");

let device = w.device();

let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);

let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-1bit")?;

crate::metal_kernels::call_dequant_1bit(
device.device(),
&command_buffer,
&crate::metal_kernels::Kernels::new(),
s.dtype(),
w.buffer(),
s.buffer(),
z.buffer(),
self.h as u32,
self.w as u32,
&output,
)
.map_err(candle_core::Error::wrap)?;

let newstorage = candle_core::MetalStorage::new(
output,
device.clone(),
out_shape.elem_count(),
s.dtype(),
);
Ok((newstorage, out_shape))
}
}

/*
Expand Down Expand Up @@ -314,4 +514,54 @@ impl CustomOp3 for Dequant3Bit {
(_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
}
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
w: &candle_core::MetalStorage,
l_w: &Layout,
s: &candle_core::MetalStorage,
l_s: &Layout,
z: &candle_core::MetalStorage,
l_z: &Layout,
) -> Result<(candle_core::MetalStorage, Shape)> {
const PACK_FACTOR: usize = 10;

if w.dtype() != DType::I32 {
candle_core::bail!("Weight must be i32, HQQ dequant 3-bit");
};
if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
candle_core::bail!("All inputs must be contiguous");
}

let command_buffer = w.device().command_buffer()?;
command_buffer.set_label("dequant-3bit");

let device = w.device();

let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);

let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-3bit")?;

crate::metal_kernels::call_dequant_3bit(
device.device(),
&command_buffer,
&crate::metal_kernels::Kernels::new(),
s.dtype(),
w.buffer(),
s.buffer(),
z.buffer(),
self.h as u32,
self.w as u32,
&output,
)
.map_err(candle_core::Error::wrap)?;

let newstorage = candle_core::MetalStorage::new(
output,
device.clone(),
out_shape.elem_count(),
s.dtype(),
);
Ok((newstorage, out_shape))
}
}
Loading

0 comments on commit 10dc437

Please sign in to comment.