From d452c4f4c0d5b8fe748eebb1b9801f44a7d13374 Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Mon, 22 Jan 2024 14:00:05 -0700 Subject: [PATCH] Fix onnx importer to treat Constant values as static (#2780) Fixes https://github.com/llvm/torch-mlir/issues/2764 In the case of OPT, there are ConstantOfShape ops whose input shape is not static (that is, not an initializer) but instead comes from a Constant op. The importer can't handle such non-static input shapes. The fix here is to create initializers for a subset of Constant ops (ones with "value" attributes), so that their outputs can be used statically. Additionally, there was no case for creating a splat of int64, so I added that as well. --------- Co-authored-by: Dave Liddell --- python/torch_mlir/extras/onnx_importer.py | 45 +++++++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index dbf0adc490bd..59a2682bbba9 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type # Handle special op types that materialize to non-op IR constructs. + # Handlers return True if the op was handled, else this function + # should process it as a general node. special_key = f"_handle_node_{op_type}" if hasattr(self, special_key): - getattr(self, special_key)(node) - return + was_handled = getattr(self, special_key)(node) + if was_handled: + return # General node import. 
input_values = [] @@ -333,8 +336,11 @@ def import_attributes( ) attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc) - def import_initializer(self, initializer: onnx.TensorProto) -> Value: - with InsertionPoint(self._b), Location.name(initializer.name): + def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: + # If an explicitly specified name is given, use that; otherwise, pick + # up the name from the tensor proto itself + iname = extern_name if extern_name else initializer.name + with InsertionPoint(self._b), Location.name(iname): value_attr = self._cc.tensor_proto_to_attr(initializer) vtensor_type = self._cc.tensor_proto_to_type(initializer) literal_op = Operation.create( @@ -342,7 +348,7 @@ def import_initializer(self, initializer: onnx.TensorProto) -> Value: results=[vtensor_type], attributes={"value": value_attr}, ) - self._nv_map[initializer.name] = literal_op.result + self._nv_map[iname] = literal_op.result return literal_op.result def _get_immediate_tensor(self, name: str) -> np.array: @@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array: f"Unhandled ONNX TensorProto immediate data: {initializer}" ) - def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): + def _handle_node_Constant(self, node: onnx.NodeProto) -> bool: + # Special case only for constants specified by value attribute (for now) + value_proto = _get_attr(node, "value", False) + if not value_proto: + return False + + # Produce an initializer for the constant, so that it can be used in + # combination with other ops, such as ConstantOfShape, requiring + # a constant input + assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR + assert len(node.output) == 1 + const_name = node.output[0] + self.import_initializer(value_proto.t, const_name) + self._gi.initializer_map[const_name] = value_proto.t + return True + + def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool: # This op is 
special: It has an input of the shape, and in full generality # could involve eager production of constants of variable size. In # practice, the DNN profile for ONNX makes this very difficult to do @@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): attributes={"value": value_attr}, ) self._nv_map[node.output[0]] = literal_op.result + return True class ContextCache: @@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat( RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) ), + onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( + IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little", + signed=True) if tp.HasField("raw_data") else tp.int64_data[0]) + ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB } @@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: } -def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto: +def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto: for attr in node.attribute: if attr.name == attr_name: return attr - else: + if is_required: raise OnnxImportError(f"Required attribute {attr_name} not found in {node}") + return None