diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d19f099f71..b8c71ede44 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,7 +1,7 @@
 //! Methods for backpropagation of gradients.
 use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 // arg has been reduced to node via reduce_dims, expand it back to arg.
 // This has to handle keepdims.
@@ -145,6 +145,11 @@ impl Tensor {
                         }
                     }
                     Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
+                    Op::QMatmul(_, lhs) => {
+                        let (tg, nodes) = walk(lhs, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
                 }
             } else {
                 nodes
@@ -162,6 +167,12 @@ impl Tensor {
 
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
+        let mut filtered_node_ids = HashSet::new();
+        for node in &sorted_nodes {
+            filtered_node_ids.insert(node.id());
+        }
+        let is_tensor_grad_needed =
+            |tensor: &Tensor| -> bool { filtered_node_ids.contains(&tensor.id()) };
         let mut grads = GradStore::new();
         grads.insert(self, self.ones_like()?.contiguous()?);
         for node in sorted_nodes.iter() {
@@ -181,56 +192,80 @@ impl Tensor {
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        }
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+                        }
                     }
                     Op::Binary(lhs, rhs, BinaryOp::Sub) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        }
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
+                        }
                     }
                     Op::Binary(lhs, rhs, BinaryOp::Mul) => {
-                        let lhs_grad = grad.mul(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                        let rhs_grad = grad.mul(lhs)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_grad = grad.mul(rhs)?;
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        }
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_grad = grad.mul(lhs)?;
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        }
                     }
                     Op::Binary(lhs, rhs, BinaryOp::Div) => {
-                        let lhs_grad = grad.div(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_grad = grad.div(rhs)?;
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        }
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
+                        }
                     }
                     Op::Binary(lhs, rhs, BinaryOp::Minimum)
                     | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                         let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                         let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
-                        // If both masks are 1 one the same point, we want to scale the
-                        // gradient by 0.5 rather than 1.
-                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            // If both masks are 1 on the same point, we want to scale the
+                            // gradient by 0.5 rather than 1.
+                            let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        }
 
-                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        }
                     }
                     Op::WhereCond(pred, t, f) => {
                         let zeros = grad.zeros_like()?;
-                        let t_sum_grad = grads.or_insert(t)?;
-                        let t_grad = pred.where_cond(&grad, &zeros)?;
-                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
-                        let f_sum_grad = grads.or_insert(f)?;
-                        let f_grad = pred.where_cond(&zeros, &grad)?;
-                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
+                        if is_tensor_grad_needed(t) {
+                            let t_sum_grad = grads.or_insert(t)?;
+                            let t_grad = pred.where_cond(&grad, &zeros)?;
+                            *t_sum_grad = t_sum_grad.add(&t_grad)?;
+                        }
+                        if is_tensor_grad_needed(f) {
+                            let f_sum_grad = grads.or_insert(f)?;
+                            let f_grad = pred.where_cond(&zeros, &grad)?;
+                            *f_sum_grad = f_sum_grad.add(&f_grad)?;
+                        }
                     }
                     Op::Conv1D {
                         arg,
@@ -239,37 +274,40 @@ impl Tensor {
                         stride,
                         dilation,
                     } => {
-                        // The output height for conv_transpose1d is:
-                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
-                        let grad_l_in = grad.dim(2)?;
-                        let k_size = kernel.dim(2)?;
-                        let out_size =
-                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
-                        let out_padding = arg.dim(2)? - out_size;
-                        let grad_arg = grad.conv_transpose1d(
-                            kernel,
-                            *padding,
-                            out_padding,
-                            *stride,
-                            *dilation,
-                            /* groups */ 1,
-                        )?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
-
-                        let grad_kernel = arg
-                            .transpose(0, 1)?
-                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0) = kernel.dims3()?;
-                        let (_, _, g_k0) = grad_kernel.dims3()?;
-                        let grad_kernel = if g_k0 != k0 {
-                            grad_kernel.narrow(2, 0, k0)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                        if is_tensor_grad_needed(arg) {
+                            // The output height for conv_transpose1d is:
+                            // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
+                            let grad_l_in = grad.dim(2)?;
+                            let k_size = kernel.dim(2)?;
+                            let out_size = (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1
+                                - 2 * padding;
+                            let out_padding = arg.dim(2)? - out_size;
+                            let grad_arg = grad.conv_transpose1d(
+                                kernel,
+                                *padding,
+                                out_padding,
+                                *stride,
+                                *dilation,
+                                /* groups */ 1,
+                            )?;
+                            let sum_grad = grads.or_insert(arg)?;
+                            *sum_grad = sum_grad.add(&grad_arg)?;
+                        }
+                        if is_tensor_grad_needed(kernel) {
+                            let grad_kernel = arg
+                                .transpose(0, 1)?
+                                .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                                .transpose(0, 1)?;
+                            let sum_grad = grads.or_insert(kernel)?;
+                            let (_, _, k0) = kernel.dims3()?;
+                            let (_, _, g_k0) = grad_kernel.dims3()?;
+                            let grad_kernel = if g_k0 != k0 {
+                                grad_kernel.narrow(2, 0, k0)?
+                            } else {
+                                grad_kernel
+                            };
+                            *sum_grad = sum_grad.add(&grad_kernel)?;
+                        }
                     }
                     Op::Conv2D {
                         arg,
@@ -278,36 +316,40 @@ impl Tensor {
                         stride,
                         dilation,
                     } => {
-                        // The output height for conv_transpose2d is:
-                        // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
-                        let grad_h = grad.dim(2)?;
-                        let k_h = kernel.dim(2)?;
-                        let out_size =
-                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
-                        let out_padding = arg.dim(2)? - out_size;
-                        let grad_arg = grad.conv_transpose2d(
-                            kernel,
-                            *padding,
-                            out_padding,
-                            *stride,
-                            *dilation,
-                        )?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
+                        if is_tensor_grad_needed(arg) {
+                            // The output height for conv_transpose2d is:
+                            // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
+                            let grad_h = grad.dim(2)?;
+                            let k_h = kernel.dim(2)?;
+                            let out_size =
+                                (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
+                            let out_padding = arg.dim(2)? - out_size;
+                            let grad_arg = grad.conv_transpose2d(
+                                kernel,
+                                *padding,
+                                out_padding,
+                                *stride,
+                                *dilation,
+                            )?;
+                            let sum_grad = grads.or_insert(arg)?;
+                            *sum_grad = sum_grad.add(&grad_arg)?;
+                        }
 
-                        let grad_kernel = arg
-                            .transpose(0, 1)?
-                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0, k1) = kernel.dims4()?;
-                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
-                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
-                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                        if is_tensor_grad_needed(kernel) {
+                            let grad_kernel = arg
+                                .transpose(0, 1)?
+                                .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                                .transpose(0, 1)?;
+                            let sum_grad = grads.or_insert(kernel)?;
+                            let (_, _, k0, k1) = kernel.dims4()?;
+                            let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                            let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                                grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                            } else {
+                                grad_kernel
+                            };
+                            *sum_grad = sum_grad.add(&grad_kernel)?;
+                        }
                     }
                     Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose1d",
@@ -320,23 +362,27 @@ impl Tensor {
                         dilation,
                         output_padding: _output_padding,
                     } => {
-                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
+                        if is_tensor_grad_needed(arg) {
+                            let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
+                            let sum_grad = grads.or_insert(arg)?;
+                            *sum_grad = sum_grad.add(&grad_arg)?;
+                        }
 
-                        let grad_kernel = grad
-                            .transpose(0, 1)?
-                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0, k1) = kernel.dims4()?;
-                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
-                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
-                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                        if is_tensor_grad_needed(kernel) {
+                            let grad_kernel = grad
+                                .transpose(0, 1)?
+                                .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                                .transpose(0, 1)?;
+                            let sum_grad = grads.or_insert(kernel)?;
+                            let (_, _, k0, k1) = kernel.dims4()?;
+                            let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                            let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                                grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                            } else {
+                                grad_kernel
+                            };
+                            *sum_grad = sum_grad.add(&grad_kernel)?;
+                        }
                     }
                     Op::AvgPool2D {
                         arg,
@@ -407,33 +453,45 @@ impl Tensor {
                         *sum_grad = conv_sum;
                     }
                     Op::SliceScatter0(lhs, rhs, start_rhs) => {
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                            *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        }
 
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                        }
                     }
                     Op::Gather(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                     }
                     Op::ScatterAdd(init, indexes, src, dim) => {
-                        let init_sum_grad = grads.or_insert(init)?;
-                        *init_sum_grad = init_sum_grad.add(&grad)?;
+                        if is_tensor_grad_needed(init) {
+                            let init_sum_grad = grads.or_insert(init)?;
+                            *init_sum_grad = init_sum_grad.add(&grad)?;
+                        }
 
-                        let src_grad = grad.gather(indexes, *dim)?;
-                        let src_sum_grad = grads.or_insert(src)?;
-                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                        if is_tensor_grad_needed(src) {
+                            let src_grad = grad.gather(indexes, *dim)?;
+                            let src_sum_grad = grads.or_insert(src)?;
+                            *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                        }
                     }
                     Op::IndexAdd(init, indexes, src, dim) => {
-                        let init_sum_grad = grads.or_insert(init)?;
-                        *init_sum_grad = init_sum_grad.add(&grad)?;
+                        if is_tensor_grad_needed(init) {
+                            let init_sum_grad = grads.or_insert(init)?;
+                            *init_sum_grad = init_sum_grad.add(&grad)?;
+                        }
 
-                        let src_grad = grad.index_select(indexes, *dim)?;
-                        let src_sum_grad = grads.or_insert(src)?;
-                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                        if is_tensor_grad_needed(src) {
+                            let src_grad = grad.index_select(indexes, *dim)?;
+                            let src_sum_grad = grads.or_insert(src)?;
+                            *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                        }
                     }
                     Op::IndexSelect(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
@@ -443,21 +501,42 @@ impl Tensor {
                     Op::Matmul(lhs, rhs) => {
                         // Skipping checks, the op went ok, we can skip
                         // the matmul size checks for now.
-                        let lhs_grad = grad.matmul(&rhs.t()?)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_grad = grad.matmul(&rhs.t()?)?;
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        }
 
-                        let rhs_grad = lhs.t()?.matmul(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        if is_tensor_grad_needed(rhs) {
+                            let rhs_grad = lhs.t()?.matmul(&grad)?;
+                            let rhs_sum_grad = grads.or_insert(rhs)?;
+                            *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        }
+                    }
+                    Op::QMatmul(rhs, lhs) => {
+                        let rhs = rhs.dequantize(self.device())?.t()?;
+                        let rhs = match *grad.dims() {
+                            [b1, b2, _, _] => rhs.broadcast_left((b1, b2))?.t()?,
+                            [bsize, _, _] => rhs.broadcast_left(bsize)?.t()?,
+                            _ => rhs.t()?,
+                        };
+                        let rhs = rhs.detach();
+                        if is_tensor_grad_needed(lhs) {
+                            let lhs_grad = grad.matmul(&rhs)?;
+                            let lhs_sum_grad = grads.or_insert(lhs)?;
+                            *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                            *lhs_sum_grad = lhs_sum_grad.detach();
+                        }
                     }
                     Op::Cat(args, dim) => {
                         let mut start_idx = 0;
                         for arg in args {
                             let len = arg.dims()[*dim];
-                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
-                            let sum_grad = grads.or_insert(arg)?;
-                            *sum_grad = sum_grad.add(&arg_grad)?;
+                            if is_tensor_grad_needed(arg) {
+                                let arg_grad = grad.narrow(*dim, start_idx, len)?;
+                                let sum_grad = grads.or_insert(arg)?;
+                                *sum_grad = sum_grad.add(&arg_grad)?;
+                            }
                             start_idx += len;
                         }
                     }
@@ -652,29 +731,39 @@ impl Tensor {
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
-                        if let Some(arg_grad1) = arg_grad1 {
-                            let sum_grad = grads.or_insert(arg1)?;
-                            *sum_grad = sum_grad.add(&arg_grad1)?
+                        if is_tensor_grad_needed(arg1) {
+                            if let Some(arg_grad1) = arg_grad1 {
+                                let sum_grad = grads.or_insert(arg1)?;
+                                *sum_grad = sum_grad.add(&arg_grad1)?
+                            }
                         }
-                        if let Some(arg_grad2) = arg_grad2 {
-                            let sum_grad = grads.or_insert(arg2)?;
-                            *sum_grad = sum_grad.add(&arg_grad2)?
+                        if is_tensor_grad_needed(arg2) {
+                            if let Some(arg_grad2) = arg_grad2 {
+                                let sum_grad = grads.or_insert(arg2)?;
+                                *sum_grad = sum_grad.add(&arg_grad2)?
+                            }
                         }
                     }
                     Op::CustomOp3(arg1, arg2, arg3, c) => {
                         let (arg_grad1, arg_grad2, arg_grad3) =
                             c.bwd(arg1, arg2, arg3, node, &grad)?;
-                        if let Some(arg_grad1) = arg_grad1 {
-                            let sum_grad = grads.or_insert(arg1)?;
-                            *sum_grad = sum_grad.add(&arg_grad1)?
+                        if is_tensor_grad_needed(arg1) {
+                            if let Some(arg_grad1) = arg_grad1 {
+                                let sum_grad = grads.or_insert(arg1)?;
+                                *sum_grad = sum_grad.add(&arg_grad1)?
+                            }
                         }
-                        if let Some(arg_grad2) = arg_grad2 {
-                            let sum_grad = grads.or_insert(arg2)?;
-                            *sum_grad = sum_grad.add(&arg_grad2)?
+                        if is_tensor_grad_needed(arg2) {
+                            if let Some(arg_grad2) = arg_grad2 {
+                                let sum_grad = grads.or_insert(arg2)?;
+                                *sum_grad = sum_grad.add(&arg_grad2)?
+                            }
                         }
-                        if let Some(arg_grad3) = arg_grad3 {
-                            let sum_grad = grads.or_insert(arg3)?;
-                            *sum_grad = sum_grad.add(&arg_grad3)?
+                        if is_tensor_grad_needed(arg3) {
+                            if let Some(arg_grad3) = arg_grad3 {
+                                let sum_grad = grads.or_insert(arg3)?;
+                                *sum_grad = sum_grad.add(&arg_grad3)?
+                            }
                         }
                     }
                     Op::Unary(arg, UnaryOp::Sqr) => {
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index c5fc3fc475..2e67803c15 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -1,9 +1,11 @@
 //! Tensor Opertion Enums and Traits
 //!
 #![allow(clippy::redundant_closure_call)]
+use crate::quantized::QTensor;
 use crate::Tensor;
 use half::{bf16, f16};
 use num_traits::float::Float;
+use std::sync::Arc;
 
 #[derive(Clone, Copy, PartialEq, Eq)]
 pub enum CmpOp {
@@ -179,6 +181,8 @@ pub enum Op {
         Tensor,
         std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
     ),
+
+    QMatmul(std::sync::Arc<QTensor>, Tensor),
 }
 
 pub trait UnaryOpT {
@@ -918,6 +922,15 @@ impl BackpropOp {
         Self(op)
     }
 
+    pub(crate) fn new_qmatmul(arg1: &Arc<QTensor>, arg2: &Tensor) -> Self {
+        let op = if arg2.track_op() {
+            Some(Op::QMatmul(arg1.clone(), arg2.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
     pub(crate) fn is_none(&self) -> bool {
         self.0.is_none()
     }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 802c5691f0..7277fc8234 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -527,7 +527,7 @@ impl crate::CustomOp1 for QTensor {
 impl crate::Module for QMatMul {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
-            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+            Self::QTensor(t) => xs.qmatmul(t),
             Self::Tensor(w) => {
                 let w = match *xs.dims() {
                     [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 3169928893..7ec4ebb6b3 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -6,6 +6,7 @@ use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims};
 use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
+use crate::quantized::QTensor;
 
 /// Unique identifier for tensors.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -1279,6 +1280,16 @@ impl Tensor {
         Ok(from_storage(storage, c_shape, op, false))
     }
 
+    pub fn qmatmul(&self, rhs: &Arc<QTensor>) -> Result<Tensor> {
+        let (storage, shape) = self.storage().apply_op1(self.layout(), rhs.as_ref())?;
+        Ok(from_storage(
+            storage,
+            shape,
+            BackpropOp::new_qmatmul(rhs, self),
+            false,
+        ))
+    }
+
     /// Matrix-multiplication with broadcasting support.
     ///
     /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
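
The backprop.rs changes above all follow one pattern: every gradient accumulation is gated behind is_tensor_grad_needed, so a tensor that never appeared in the topological walk (a constant, a detached value, or the quantized weights of a QMatmul) never has a gradient buffer allocated for it. The following is a minimal standalone sketch of that gating idea, using plain integer ids in place of candle's TensorId and GradStore; the names and values are illustrative, not taken from the diff.

use std::collections::{HashMap, HashSet};

fn main() {
    // Ids a hypothetical sorted_nodes() walk would return; id 7 stands in
    // for a constant that tracked no op and never made it into the list.
    let sorted_nodes: Vec<u64> = vec![1, 2, 5];
    let filtered_node_ids: HashSet<u64> = sorted_nodes.iter().copied().collect();
    let is_grad_needed = |id: u64| filtered_node_ids.contains(&id);

    // Backward pass: accumulate upstream gradients only for needed nodes.
    let mut grads: HashMap<u64, f64> = HashMap::new();
    for (id, upstream) in [(1u64, 0.5), (7, 0.25), (5, 1.0)] {
        if is_grad_needed(id) {
            *grads.entry(id).or_insert(0.0) += upstream;
        }
    }
    assert!(!grads.contains_key(&7)); // no buffer was ever allocated for 7
    println!("{grads:?}");
}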
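Taken together, the op.rs, quantized/mod.rs and tensor.rs changes mean a loss computed through QMatMul::forward can now be differentiated with respect to its float input, while the quantized weights stay gradient-free (the dequantized copy is detach()ed in the backward arm). Below is a hedged end-to-end sketch of that path; the shapes, the Q4_0 dtype, the sum_all loss, and the exact QTensor::quantize / QMatMul::from_qtensor signatures are assumptions based on recent candle versions, not part of the diff.

use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Q4_0 blocks span 32 elements, so the inner dimension must be a multiple of 32.
    let w = Tensor::randn(0f32, 1.0, (16, 64), &dev)?;
    let qw = QTensor::quantize(&w, GgmlDType::Q4_0)?;
    let qmm = QMatMul::from_qtensor(qw)?;

    // Gradients only flow from variables, so the input is created as a Var.
    let x = Var::randn(0f32, 1.0, (4, 64), &dev)?;
    let y = qmm.forward(&x)?; // records Op::QMatmul via the new Tensor::qmatmul
    let grads = y.sum_all()?.backward()?;

    // The input receives a gradient; the quantized weights never appear in
    // the gradient store since they are not graph nodes.
    let dx = grads.get(&x).expect("input gradient");
    assert_eq!(dx.dims2()?, (4, 64));
    Ok(())
}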