Skip to content

Commit

Permalink
Merge pull request #372 from robertknight/widen-float16
Browse files (browse the repository at this point in the history)
Widen float16 constants to float32 in model converter
  • Loading branch information
robertknight authored Sep 21, 2024
2 parents 3463b62 + f781ff7 commit 74ce370
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,16 +436,26 @@ def constant_node_from_onnx_initializer(
) -> ConstantNode:
dims = list(tensor.dims)
data = numpy_helper.to_array(tensor)
dtype_name = data.dtype.name

match data.dtype.name:
match dtype_name:
# Types that don't need to change
case "float32" | "int8" | "int32" | "uint8":
pass

# Int types that can be widened to int32
# Int types that are not supported natively, but can be widened to
# int32.
case "bool" | "int16":
data = data.astype(np.int32)

# Float types that are not supported natively, but can be widened to
# float32.
case "float16":
warn_once(
f"Converting {dtype_name} weights to float32 because {dtype_name} is not supported natively yet. This will increase model size."
)
data = data.astype(np.float32)

# Types that need to be narrowed
case "int64":
# Some ONNX exporters use `INT_MIN` and `INT_MAX` to represent
Expand Down

0 comments on commit 74ce370

Please sign in to comment.