Skip to content

Commit

Permalink
Pre-transpose constant MatMul operand
Browse files Browse the repository at this point in the history
  • Loading branch information
hsfzxjy committed Aug 22, 2024
1 parent 2024f4e commit 9b616de
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 19 deletions.
15 changes: 15 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ConstantNode(Node):
"""

shape: list[int]
strides: Optional[list[int]]
data: np.ndarray

def __init__(self, name: str, shape: list[int], data: np.ndarray):
Expand Down Expand Up @@ -811,6 +812,12 @@ def op_node_from_onnx_operator(
op_reader.check_attr("input_forget", "int", 0)
op_reader.check_attr("layout", "int", 0)

case "MatMul":
b = constant_nodes.get(onnx_op.input[-1])
if b and len(b.shape) == 2 and b.shape[-1] > 1:
b.data = np.ascontiguousarray(b.data.transpose())
b.strides = [1, b.shape[0]]

case "MaxPool":
attrs = sg.MaxPoolAttrsT()
kernel_shape = op_reader.require_attr("kernel_shape", "ints")
Expand Down Expand Up @@ -1141,6 +1148,12 @@ def build_constant_node(
shape_vec = write_vec(
builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32"
)
if getattr(constant, "strides", None):
strides_vec = write_vec(
builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32"
)
else:
strides_vec = None
n_elems = reduce(mul, constant.shape, 1)
assert n_elems == constant.data.size, "constant shape does not match element count"

Expand Down Expand Up @@ -1182,6 +1195,8 @@ def build_constant_node(
sg.ConstantNodeStart(builder)
sg.ConstantNodeAddShape(builder, shape_vec)
sg.ConstantNodeAddDtype(builder, dtype)
if strides_vec:
sg.ConstantNodeAddStrides(builder, strides_vec)

if inline_data:
sg.ConstantNodeAddDataType(builder, inline_data_type)
Expand Down
69 changes: 60 additions & 9 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5145,15 +5145,42 @@ def ShapeIsNone(self):
return o == 0

# ConstantNode
def DataType(self):
# Return element `j` of the `strides` vector, read as a uint32.
# Returns 0 when the field is not present in this table.
def Strides(self, j):
# `o` is 0 when the field is absent (see the `o != 0` check below).
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
a = self._tab.Vector(o)
# Elements are addressed 4 bytes apart: element j is at `a + j * 4`.
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0

# ConstantNode
# Return the whole `strides` vector via `GetVectorAsNumpy` (uint32 elements).
# Returns 0 (not an empty array) when the field is not present.
def StridesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
return 0

# ConstantNode
# Return the number of elements in the `strides` vector, or 0 when the
# field is not present.
def StridesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.VectorLen(o)
return 0

# ConstantNode
# Return True when the `strides` field is absent from this table, i.e. the
# offset lookup yields 0.
def StridesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0

# ConstantNode
# Return the `DataType` field, read as a uint8, or 0 when the field is
# not present in this table.
def DataType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
# The value is stored inline, at offset `o` from the table position.
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
return 0

# ConstantNode
def Data(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
from flatbuffers.table import Table
obj = Table(bytearray(), 0)
Expand All @@ -5163,38 +5190,44 @@ def Data(self):

# ConstantNode
def Dtype(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
return None

# ConstantNode
def DataOffset(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
return None

def ConstantNodeStart(builder):
builder.StartObject(5)
builder.StartObject(6)

# Store `shape` (an offset to an already-built vector) in slot 0 of the
# ConstantNode table currently being built by `builder`.
def ConstantNodeAddShape(builder, shape):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

# Begin building the shape vector with `numElems` elements and return the
# result of `builder.StartVector`.
# NOTE(review): the two 4s are presumably element size and alignment in
# bytes (uint32) — confirm against the FlatBuffers Builder API.
def ConstantNodeStartShapeVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

# Store `strides` (an offset to an already-built vector) in slot 1 of the
# ConstantNode table currently being built by `builder`.
def ConstantNodeAddStrides(builder, strides):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0)

# Begin building the strides vector with `numElems` elements and return the
# result of `builder.StartVector`.
# NOTE(review): the two 4s are presumably element size and alignment in
# bytes (uint32) — confirm against the FlatBuffers Builder API.
def ConstantNodeStartStridesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def ConstantNodeAddDataType(builder, dataType):
builder.PrependUint8Slot(1, dataType, 0)
builder.PrependUint8Slot(2, dataType, 0)

def ConstantNodeAddData(builder, data):
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)

def ConstantNodeAddDtype(builder, dtype):
builder.PrependUint16Slot(3, dtype, None)
builder.PrependUint16Slot(4, dtype, None)

def ConstantNodeAddDataOffset(builder, dataOffset):
builder.PrependUint64Slot(4, dataOffset, None)
builder.PrependUint64Slot(5, dataOffset, None)

# Finish the ConstantNode table started on `builder` and return whatever
# `builder.EndObject()` returns.
def ConstantNodeEnd(builder):
return builder.EndObject()
Expand All @@ -5210,6 +5243,7 @@ class ConstantNodeT(object):
# ConstantNodeT
def __init__(self):
self.shape = None # type: List[int]
self.strides = None # type: List[int]
self.dataType = 0 # type: int
self.data = None # type: Union[None, FloatDataT, IntDataT]
self.dtype = None # type: Optional[int]
Expand Down Expand Up @@ -5243,6 +5277,13 @@ def _UnPack(self, constantNode):
self.shape.append(constantNode.Shape(i))
else:
self.shape = constantNode.ShapeAsNumpy()
if not constantNode.StridesIsNone():
if np is None:
self.strides = []
for i in range(constantNode.StridesLength()):
self.strides.append(constantNode.Strides(i))
else:
self.strides = constantNode.StridesAsNumpy()
self.dataType = constantNode.DataType()
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
self.dtype = constantNode.Dtype()
Expand All @@ -5258,11 +5299,21 @@ def Pack(self, builder):
for i in reversed(range(len(self.shape))):
builder.PrependUint32(self.shape[i])
shape = builder.EndVector()
if self.strides is not None:
if np is not None and type(self.strides) is np.ndarray:
strides = builder.CreateNumpyVector(self.strides)
else:
ConstantNodeStartStridesVector(builder, len(self.strides))
for i in reversed(range(len(self.strides))):
builder.PrependUint32(self.strides[i])
strides = builder.EndVector()
if self.data is not None:
data = self.data.Pack(builder)
ConstantNodeStart(builder)
if self.shape is not None:
ConstantNodeAddShape(builder, shape)
if self.strides is not None:
ConstantNodeAddStrides(builder, strides)
ConstantNodeAddDataType(builder, self.dataType)
if self.data is not None:
ConstantNodeAddData(builder, data)
Expand Down
2 changes: 2 additions & 0 deletions src/constant_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ pub enum ConstantStorage {
/// An in-memory buffer, such as a FlatBuffers file that has been read
/// into memory using functions from `std::fs`.
Buffer(Vec<u8>),
StaticSlice(&'static [u8]),
}

impl ConstantStorage {
/// Return the data in this storage as a slice of bytes.
pub fn data(&self) -> &[u8] {
match &self {
ConstantStorage::Buffer(data) => data,
ConstantStorage::StaticSlice(data) => data,
#[cfg(feature = "mmap")]
ConstantStorage::Mmap(mmap) => mmap,
}
Expand Down
46 changes: 40 additions & 6 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ impl Model {
tensor_data_offset: Option<u64>,
) -> Result<NodeId, ModelLoadError> {
let shape: Vec<usize> = constant.shape().iter().map(|x| x as usize).collect();
let strides: Option<Vec<usize>> = constant
.strides()
.map(|strides| strides.iter().map(|x| x as usize).collect());

if let Some(data_offset) = constant.data_offset() {
// Constant data is stored outside the model buffer, in the same file.
Expand All @@ -486,13 +489,21 @@ impl Model {

let graph_node = match constant.dtype() {
Some(sg::ConstantDataType::Int32) => {
let const_data =
constant_data_from_storage_offset::<i32>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<i32>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
Some(sg::ConstantDataType::Float32) => {
let const_data =
constant_data_from_storage_offset::<f32>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<f32>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
_ => {
Expand Down Expand Up @@ -717,6 +728,7 @@ fn transmute_bytes<T: Pod>(bytes: &[u8]) -> Option<&[T]> {
fn constant_data_from_storage_offset<T: LeBytes + Pod>(
storage: &Arc<ConstantStorage>,
shape: &[usize],
strides: Option<&[usize]>,
offset: usize,
) -> Result<ConstantNodeData<T>, ModelLoadError> {
let n_elements: usize = shape.iter().product();
Expand All @@ -731,14 +743,36 @@ fn constant_data_from_storage_offset<T: LeBytes + Pod>(
if let Some(elements) = transmute_bytes(bytes) {
let storage =
ArcSlice::new(storage.clone(), elements).expect("storage does not contain data");
let const_data: ConstantNodeData<T> = ArcTensorView::from_data(shape, storage).into();
let const_data: ConstantNodeData<T> = if let Some(strides) = strides {
ArcTensorView::from_data_with_strides(shape, storage, strides)
.map_err(|_| {
ModelLoadError::GraphError(format!(
"bad strides = {:?}, shape = {:?}",
strides, shape
))
})?
.into()
} else {
ArcTensorView::from_data(shape, storage).into()
};
Ok(const_data)
} else {
let data: Vec<T> = bytes
.chunks(std::mem::size_of::<T>())
.map(|chunk| T::from_le_bytes(chunk.try_into().unwrap()))
.collect();
Ok(Tensor::from_data(shape, data).into())
Ok(if let Some(strides) = strides {
Tensor::from_data_with_strides(shape, data, strides)
.map_err(|_| {
ModelLoadError::GraphError(format!(
"bad strides = {:?}, shape = {:?}",
strides, shape
))
})?
.into()
} else {
Tensor::from_data(shape, data).into()
})
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {

sg::ConstantNodeArgs {
shape: Some(shape_vec),
strides: None,
data_type: sg::ConstantData::NONE,
data: None,
data_offset: Some(offset),
Expand All @@ -294,6 +295,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {

sg::ConstantNodeArgs {
shape: Some(shape_vec),
strides: None,
data_type: inline_dtype,
data: Some(data),
data_offset: None,
Expand Down
1 change: 1 addition & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ enum ConstantDataType: ushort {
// Graph node for a constant tensor value, whose data is part of the model.
table ConstantNode {
shape:[uint] (required);
strides:[uint];

// Tensor data embedded within the model file.
data:ConstantData;
Expand Down
Loading

0 comments on commit 9b616de

Please sign in to comment.