Skip to content

Commit

Permalink
Merge pull request #345 from robertknight/rten-model-u8i8
Browse files Browse the repository at this point in the history
Support u8 and i8 tensors in operator inputs, outputs and model files
  • Loading branch information
robertknight authored Sep 6, 2024
2 parents 9250c65 + 387cdb8 commit 3d50c14
Show file tree
Hide file tree
Showing 27 changed files with 902 additions and 226 deletions.
4 changes: 3 additions & 1 deletion rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ fn run_with_random_input(
for (i, (output, name)) in outputs.iter().zip(output_names).enumerate() {
let dtype = match output {
Output::FloatTensor(_) => "f32",
Output::IntTensor(_) => "i32",
Output::Int32Tensor(_) => "i32",
Output::Int8Tensor(_) => "i8",
Output::UInt8Tensor(_) => "u8",
};
println!(
" Output {i} \"{name}\" data type {} shape: {:?}",
Expand Down
34 changes: 25 additions & 9 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, name: str, shape: list[int], data: np.ndarray):

# Verify that this is a data type that we'll be able to serialize later.
match data.dtype:
case np.float32 | np.int32:
case np.float32 | np.int32 | np.int8 | np.uint8:
pass
case _:
dtype_name: str = data.dtype.name # type:ignore[union-attr]
Expand Down Expand Up @@ -439,11 +439,11 @@ def constant_node_from_onnx_initializer(

match data.dtype.name:
# Types that don't need to change
case "float32" | "int32":
case "float32" | "int8" | "int32" | "uint8":
pass

# Int types that can be widened to int32
case "bool" | "int8" | "int16":
case "bool" | "int16":
data = data.astype(np.int32)

# Types that need to be narrowed
Expand Down Expand Up @@ -1149,14 +1149,22 @@ def build_constant_node(
inline_data_type = sg.ConstantData.FloatData
dtype = sg.ConstantDataType.Float32
case np.int32:
inline_data_type = sg.ConstantData.IntData
inline_data_type = sg.ConstantData.Int32Data
dtype = sg.ConstantDataType.Int32
case np.int8:
inline_data_type = sg.ConstantData.Int8Data
dtype = sg.ConstantDataType.Int8
case np.uint8:
inline_data_type = sg.ConstantData.UInt8Data
dtype = sg.ConstantDataType.UInt8
case _:
raise ValueError(f"Unsupported data array type {constant.data.dtype.name}") # type:ignore[union-attr]

# Store inline if we're generating the V1 format, or the tensor is small.
# Small values are mostly parameters such as axes, slice ranges etc.
store_inline = tensor_data is None or n_elems <= 16
store_inline = (
tensor_data is None or n_elems <= 16
) and inline_data_type is not None
inline_data = None
data_offset = None

Expand All @@ -1168,12 +1176,20 @@ def build_constant_node(
sg.FloatDataAddData(builder, inline_data_vec)
inline_data = sg.FloatDataEnd(builder)
case np.int32:
sg.IntDataStart(builder)
sg.IntDataAddData(builder, inline_data_vec)
inline_data = sg.IntDataEnd(builder)
sg.Int32DataStart(builder)
sg.Int32DataAddData(builder, inline_data_vec)
inline_data = sg.Int32DataEnd(builder)
case np.int8:
sg.Int8DataStart(builder)
sg.Int8DataAddData(builder, inline_data_vec)
inline_data = sg.Int8DataEnd(builder)
case np.uint8:
sg.UInt8DataStart(builder)
sg.UInt8DataAddData(builder, inline_data_vec)
inline_data = sg.UInt8DataEnd(builder)
case _:
raise ValueError(
f"Unsupported data array type {constant.data.dtype.name}" # type:ignore
f"Unsupported data type for inline storage {constant.data.dtype.name}" # type:ignore
)
else:
assert tensor_data
Expand Down
Loading

0 comments on commit 3d50c14

Please sign in to comment.