Fix onnx importer to treat Constant values as static (#2780)
Fixes #2764

In the case of OPT, there are ConstantOfShape ops whose shape input is
not static (that is, not a graph initializer), but rather comes from a
Constant op. The importer can't handle such non-static shape inputs.

The fix here is to create initializers for a subset of Constant ops
(ones with "value" attributes), so that their outputs can be used
statically. Additionally, there was no case for creating a splat of
int64, so I added that as well.
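
To make the failing pattern concrete, here is a minimal sketch (not part of this change) that builds such a graph with the standard `onnx.helper` API; all names are made up:

```python
import onnx
from onnx import helper, TensorProto

# A Constant node, rather than a graph initializer, produces the shape...
shape_const = helper.make_node(
    "Constant", inputs=[], outputs=["shape"],
    value=helper.make_tensor("shape_val", TensorProto.INT64, dims=[2], vals=[3, 4]),
)

# ...and ConstantOfShape consumes it. Before this fix, the importer could
# only resolve shape inputs that were initializers.
fill = helper.make_node(
    "ConstantOfShape", inputs=["shape"], outputs=["filled"],
    value=helper.make_tensor("fill_val", TensorProto.FLOAT, dims=[1], vals=[0.0]),
)

graph = helper.make_graph(
    [shape_const, fill], "repro", inputs=[],
    outputs=[helper.make_tensor_value_info("filled", TensorProto.FLOAT, [3, 4])],
)
onnx.checker.check_model(helper.make_model(graph))
```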

---------

Co-authored-by: Dave Liddell <[email protected]>
daveliddell and Dave Liddell authored Jan 22, 2024
1 parent cad98e8 commit d452c4f
Showing 1 changed file with 37 additions and 8 deletions.

python/torch_mlir/extras/onnx_importer.py
@@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto):
         with InsertionPoint(self._b), Location.name(node.name):
             op_type = node.op_type
             # Handle special op types that materialize to non-op IR constructs.
+            # Handlers return True if the op was handled, else this function
+            # should process it as a general node.
             special_key = f"_handle_node_{op_type}"
             if hasattr(self, special_key):
-                getattr(self, special_key)(node)
-                return
+                was_handled = getattr(self, special_key)(node)
+                if was_handled:
+                    return

             # General node import.
             input_values = []
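
The dispatch above is a plain `getattr` lookup keyed on the op type. A standalone sketch of the pattern (made-up class, not the importer's actual code):

```python
# Handlers named _handle_node_<OpType> are looked up dynamically; returning
# False falls through to the general import path.
class NodeImporter:
    def import_node(self, op_type: str) -> str:
        special_key = f"_handle_node_{op_type}"
        if hasattr(self, special_key):
            was_handled = getattr(self, special_key)()
            if was_handled:
                return "special"
        return "general"

    def _handle_node_Constant(self) -> bool:
        return True

assert NodeImporter().import_node("Constant") == "special"
assert NodeImporter().import_node("Add") == "general"
```
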
@@ -333,16 +336,19 @@ def import_attributes(
                 )
             attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)

-    def import_initializer(self, initializer: onnx.TensorProto) -> Value:
-        with InsertionPoint(self._b), Location.name(initializer.name):
+    def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value:
+        # If an explicitly specified name is given, use that; otherwise, pick
+        # up the name from the tensor proto itself.
+        iname = extern_name if extern_name else initializer.name
+        with InsertionPoint(self._b), Location.name(iname):
             value_attr = self._cc.tensor_proto_to_attr(initializer)
             vtensor_type = self._cc.tensor_proto_to_type(initializer)
             literal_op = Operation.create(
                 name="torch.vtensor.literal",
                 results=[vtensor_type],
                 attributes={"value": value_attr},
             )
-            self._nv_map[initializer.name] = literal_op.result
+            self._nv_map[iname] = literal_op.result
             return literal_op.result

     def _get_immediate_tensor(self, name: str) -> np.array:
@@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array:
                 f"Unhandled ONNX TensorProto immediate data: {initializer}"
             )

-    def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
+    def _handle_node_Constant(self, node: onnx.NodeProto) -> bool:
+        # Special case only for constants specified by value attribute (for now).
+        value_proto = _get_attr(node, "value", False)
+        if not value_proto:
+            return False
+
+        # Produce an initializer for the constant, so that it can be used in
+        # combination with other ops, such as ConstantOfShape, requiring
+        # a constant input.
+        assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
+        assert len(node.output) == 1
+        const_name = node.output[0]
+        self.import_initializer(value_proto.t, const_name)
+        self._gi.initializer_map[const_name] = value_proto.t
+        return True
+
+    def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool:
         # This op is special: It has an input of the shape, and in full generality
         # could involve eager production of constants of variable size. In
         # practice, the DNN profile for ONNX makes this very difficult to do
@@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
             attributes={"value": value_attr},
         )
         self._nv_map[node.output[0]] = literal_op.result
+        return True


 class ContextCache:
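
For reference, the data the new `_handle_node_Constant` inspects looks like this; a minimal sketch using the public onnx API (names are illustrative):

```python
import onnx
from onnx import helper, numpy_helper, TensorProto

node = helper.make_node(
    "Constant", inputs=[], outputs=["shape"],
    value=helper.make_tensor("t", TensorProto.INT64, dims=[2], vals=[3, 4]),
)
# Mirrors the handler's assertions: exactly one output, and a "value"
# attribute of TENSOR type whose payload is in value_attr.t.
(value_attr,) = [a for a in node.attribute if a.name == "value"]
assert value_attr.type == onnx.AttributeProto.AttributeType.TENSOR
assert len(node.output) == 1
print(node.output[0], numpy_helper.to_array(value_attr.t))  # shape [3 4]
```
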
@@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
     onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat(
         RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0])
     ),
+    onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat(
+        RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get(
+            IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little",
+            signed=True) if tp.HasField("raw_data") else tp.int64_data[0])
+    ),
     # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
 }
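
The new INT64 entry has to handle both encodings a splat TensorProto may use: the typed `int64_data` field or little-endian bytes in `raw_data`. A small self-contained sketch of that decode (illustrative only):

```python
import onnx
from onnx import helper, TensorProto

# Build a one-element INT64 tensor stored as raw little-endian bytes.
splat = helper.make_tensor(
    "fill", TensorProto.INT64, dims=[1],
    vals=(-7).to_bytes(8, "little", signed=True), raw=True,
)
value = (
    int.from_bytes(splat.raw_data, "little", signed=True)
    if splat.HasField("raw_data")
    else splat.int64_data[0]
)
print(value)  # -7
```
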
@@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
 }


-def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto:
+def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto:
     for attr in node.attribute:
         if attr.name == attr_name:
             return attr
     else:
-        raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
+        if is_required:
+            raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
+    return None

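Usage of the new `is_required` flag, shown with a self-contained re-implementation of the helper (assumed equivalent; the real one raises OnnxImportError):

```python
import onnx
from onnx import helper, TensorProto

def get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True):
    # Same shape as the importer's _get_attr: optional lookup when
    # is_required is False, exception otherwise.
    for attr in node.attribute:
        if attr.name == attr_name:
            return attr
    if is_required:
        raise ValueError(f"Required attribute {attr_name} not found")
    return None

node = helper.make_node(
    "Constant", inputs=[], outputs=["out"],
    value=helper.make_tensor("t", TensorProto.INT64, dims=[1], vals=[0]),
)
assert get_attr(node, "value", False) is not None
assert get_attr(node, "value_ints", False) is None  # absent, but no exception
```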