diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py index f33875cd..468d9d7c 100644 --- a/rten-convert/rten_convert/converter.py +++ b/rten-convert/rten_convert/converter.py @@ -55,6 +55,7 @@ class ConstantNode(Node): """ shape: list[int] + strides: Optional[list[int]] data: np.ndarray def __init__(self, name: str, shape: list[int], data: np.ndarray): @@ -811,6 +812,12 @@ def op_node_from_onnx_operator( op_reader.check_attr("input_forget", "int", 0) op_reader.check_attr("layout", "int", 0) + case "MatMul": + b = constant_nodes.get(onnx_op.input[-1]) + if b and len(b.shape) == 2 and b.shape[-1] > 1: + b.data = np.ascontiguousarray(b.data.transpose()) + b.strides = [1, b.shape[0]] + case "MaxPool": attrs = sg.MaxPoolAttrsT() kernel_shape = op_reader.require_attr("kernel_shape", "ints") @@ -1141,6 +1148,12 @@ def build_constant_node( shape_vec = write_vec( builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32" ) + if getattr(constant, "strides", None): + strides_vec = write_vec( + builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32" + ) + else: + strides_vec = None n_elems = reduce(mul, constant.shape, 1) assert n_elems == constant.data.size, "constant shape does not match element count" @@ -1182,6 +1195,8 @@ def build_constant_node( sg.ConstantNodeStart(builder) sg.ConstantNodeAddShape(builder, shape_vec) sg.ConstantNodeAddDtype(builder, dtype) + if strides_vec: + sg.ConstantNodeAddStrides(builder, strides_vec) if inline_data: sg.ConstantNodeAddDataType(builder, inline_data_type) diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index 2ff066f1..6e92f61e 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -5145,15 +5145,42 @@ def ShapeIsNone(self): return o == 0 # ConstantNode - def DataType(self): + def Strides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ConstantNode + def StridesAsNumpy(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # ConstantNode + def StridesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ConstantNode + def StridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # ConstantNode + def DataType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) return 0 # ConstantNode def Data(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: from flatbuffers.table import Table obj = Table(bytearray(), 0) @@ -5163,20 +5190,20 @@ def Data(self): # ConstantNode def Dtype(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos) return None # ConstantNode def DataOffset(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) return None def ConstantNodeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def ConstantNodeAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) @@ -5184,17 +5211,23 @@ def ConstantNodeAddShape(builder, shape): def ConstantNodeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ConstantNodeAddStrides(builder, strides): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0) + +def ConstantNodeStartStridesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + def ConstantNodeAddDataType(builder, dataType): - builder.PrependUint8Slot(1, dataType, 0) + builder.PrependUint8Slot(2, dataType, 0) def ConstantNodeAddData(builder, data): - builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) def ConstantNodeAddDtype(builder, dtype): - builder.PrependUint16Slot(3, dtype, None) + builder.PrependUint16Slot(4, dtype, None) def ConstantNodeAddDataOffset(builder, dataOffset): - builder.PrependUint64Slot(4, dataOffset, None) + builder.PrependUint64Slot(5, dataOffset, None) def ConstantNodeEnd(builder): return builder.EndObject() @@ -5210,6 +5243,7 @@ class ConstantNodeT(object): # ConstantNodeT def __init__(self): self.shape = None # type: List[int] + self.strides = None # type: List[int] self.dataType = 0 # type: int self.data = None # type: Union[None, FloatDataT, IntDataT] self.dtype = None # type: Optional[int] @@ -5243,6 +5277,13 @@ def _UnPack(self, constantNode): self.shape.append(constantNode.Shape(i)) else: self.shape = constantNode.ShapeAsNumpy() + if not constantNode.StridesIsNone(): + if np is None: + self.strides = [] + for i in range(constantNode.StridesLength()): + self.strides.append(constantNode.Strides(i)) + else: + self.strides = constantNode.StridesAsNumpy() self.dataType = constantNode.DataType() self.data = ConstantDataCreator(self.dataType, constantNode.Data()) self.dtype = constantNode.Dtype() @@ -5258,11 +5299,21 @@ def Pack(self, builder): for i in reversed(range(len(self.shape))): builder.PrependUint32(self.shape[i]) shape = builder.EndVector() + if self.strides is not None: + if np is not None and type(self.strides) is np.ndarray: + strides = builder.CreateNumpyVector(self.strides) + else: + ConstantNodeStartStridesVector(builder, len(self.strides)) + for i in reversed(range(len(self.strides))): + builder.PrependUint32(self.strides[i]) + strides = builder.EndVector() if self.data is not None: data = self.data.Pack(builder) ConstantNodeStart(builder) if self.shape is not None: ConstantNodeAddShape(builder, shape) + if self.strides is not None: + ConstantNodeAddStrides(builder, strides) ConstantNodeAddDataType(builder, self.dataType) if self.data is not None: ConstantNodeAddData(builder, data) diff --git a/src/constant_storage.rs b/src/constant_storage.rs index aff271fd..b8e160b1 100644 --- a/src/constant_storage.rs +++ b/src/constant_storage.rs @@ -31,6 +31,7 @@ pub enum ConstantStorage { /// An in-memory buffer, such as a FlatBuffers file that has been read /// into memory using functions from `std::fs`. Buffer(Vec), + StaticSlice(&'static [u8]), } impl ConstantStorage { @@ -38,6 +39,7 @@ impl ConstantStorage { pub fn data(&self) -> &[u8] { match &self { ConstantStorage::Buffer(data) => data, + ConstantStorage::StaticSlice(data) => data, #[cfg(feature = "mmap")] ConstantStorage::Mmap(mmap) => mmap, } diff --git a/src/model.rs b/src/model.rs index 7e1b7bf4..a637d8a0 100644 --- a/src/model.rs +++ b/src/model.rs @@ -473,6 +473,9 @@ impl Model { tensor_data_offset: Option, ) -> Result { let shape: Vec = constant.shape().iter().map(|x| x as usize).collect(); + let strides: Option> = constant + .strides() + .map(|strides| strides.iter().map(|x| x as usize).collect()); if let Some(data_offset) = constant.data_offset() { // Constant data is stored outside the model buffer, in the same file. @@ -486,13 +489,21 @@ impl Model { let graph_node = match constant.dtype() { Some(sg::ConstantDataType::Int32) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } Some(sg::ConstantDataType::Float32) => { - let const_data = - constant_data_from_storage_offset::(storage, &shape, data_offset)?; + let const_data = constant_data_from_storage_offset::( + storage, + &shape, + strides.as_deref(), + data_offset, + )?; graph.add_constant(name, const_data) } _ => { @@ -717,6 +728,7 @@ fn transmute_bytes(bytes: &[u8]) -> Option<&[T]> { fn constant_data_from_storage_offset( storage: &Arc, shape: &[usize], + strides: Option<&[usize]>, offset: usize, ) -> Result, ModelLoadError> { let n_elements: usize = shape.iter().product(); @@ -731,14 +743,36 @@ fn constant_data_from_storage_offset( if let Some(elements) = transmute_bytes(bytes) { let storage = ArcSlice::new(storage.clone(), elements).expect("storage does not contain data"); - let const_data: ConstantNodeData = ArcTensorView::from_data(shape, storage).into(); + let const_data: ConstantNodeData = if let Some(strides) = strides { + ArcTensorView::from_data_with_strides(shape, storage, strides) + .map_err(|_| { + ModelLoadError::GraphError(format!( + "bad strides = {:?}, shape = {:?}", + strides, shape + )) + })? + .into() + } else { + ArcTensorView::from_data(shape, storage).into() + }; Ok(const_data) } else { let data: Vec = bytes .chunks(std::mem::size_of::()) .map(|chunk| T::from_le_bytes(chunk.try_into().unwrap())) .collect(); - Ok(Tensor::from_data(shape, data).into()) + Ok(if let Some(strides) = strides { + Tensor::from_data_with_strides(shape, data, strides) + .map_err(|_| { + ModelLoadError::GraphError(format!( + "bad strides = {:?}, shape = {:?}", + strides, shape + )) + })? + .into() + } else { + Tensor::from_data(shape, data).into() + }) } } diff --git a/src/model_builder.rs b/src/model_builder.rs index 74e66ca0..f2d6f93c 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -283,6 +283,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { sg::ConstantNodeArgs { shape: Some(shape_vec), + strides: None, data_type: sg::ConstantData::NONE, data: None, data_offset: Some(offset), @@ -294,6 +295,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { sg::ConstantNodeArgs { shape: Some(shape_vec), + strides: None, data_type: inline_dtype, data: Some(data), data_offset: None, diff --git a/src/schema.fbs b/src/schema.fbs index 1c30a49a..e4657352 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -480,6 +480,7 @@ enum ConstantDataType: ushort { // Graph node for a constant tensor value, whose data is part of the model. table ConstantNode { shape:[uint] (required); + strides:[uint]; // Tensor data embedded within the model file. data:ConstantData; diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 4d46bca0..9affd266 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -8734,10 +8734,11 @@ impl<'a> flatbuffers::Follow<'a> for ConstantNode<'a> { impl<'a> ConstantNode<'a> { pub const VT_SHAPE: flatbuffers::VOffsetT = 4; - pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 6; - pub const VT_DATA: flatbuffers::VOffsetT = 8; - pub const VT_DTYPE: flatbuffers::VOffsetT = 10; - pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 12; + pub const VT_STRIDES: flatbuffers::VOffsetT = 6; + pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 8; + pub const VT_DATA: flatbuffers::VOffsetT = 10; + pub const VT_DTYPE: flatbuffers::VOffsetT = 12; + pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 14; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -8755,6 +8756,9 @@ impl<'a> ConstantNode<'a> { if let Some(x) = args.data { builder.add_data(x); } + if let Some(x) = args.strides { + builder.add_strides(x); + } if let Some(x) = args.shape { builder.add_shape(x); } @@ -8780,6 +8784,19 @@ impl<'a> ConstantNode<'a> { } } #[inline] + pub fn strides(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + ConstantNode::VT_STRIDES, + None, + ) + } + } + #[inline] pub fn data_type(&self) -> ConstantData { // Safety: // Created from valid Table for this object @@ -8864,6 +8881,11 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { Self::VT_SHAPE, true, )? + .visit_field::>>( + "strides", + Self::VT_STRIDES, + false, + )? .visit_union::( "data_type", Self::VT_DATA_TYPE, @@ -8892,6 +8914,7 @@ impl flatbuffers::Verifiable for ConstantNode<'_> { } pub struct ConstantNodeArgs<'a> { pub shape: Option>>, + pub strides: Option>>, pub data_type: ConstantData, pub data: Option>, pub dtype: Option, @@ -8902,6 +8925,7 @@ impl<'a> Default for ConstantNodeArgs<'a> { fn default() -> Self { ConstantNodeArgs { shape: None, // required field + strides: None, data_type: ConstantData::NONE, data: None, dtype: None, @@ -8921,6 +8945,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A> .push_slot_always::>(ConstantNode::VT_SHAPE, shape); } #[inline] + pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(ConstantNode::VT_STRIDES, strides); + } + #[inline] pub fn add_data_type(&mut self, data_type: ConstantData) { self.fbb_.push_slot::( ConstantNode::VT_DATA_TYPE, @@ -8965,6 +8994,7 @@ impl core::fmt::Debug for ConstantNode<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("ConstantNode"); ds.field("shape", &self.shape()); + ds.field("strides", &self.strides()); ds.field("data_type", &self.data_type()); match self.data_type() { ConstantData::FloatData => {