Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IREE EP][JIT, OnnxImporter] Synchronize C++ and Python importers. #7

Open
wants to merge 1 commit into
base: iree_ep
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// (which require a pre-compilation step).

#include "core/providers/iree/compiler/jit_compiler.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h"
#include "mlir-c/BuiltinAttributes.h"

Expand Down Expand Up @@ -157,9 +158,8 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
opset_import->set_version(it.second);
}

// Unforgivably sharp edge: There is a ToGraphProto() that returns a value and another that returns a reference.
// And they differ by const-ness. We need to make sure we get the reference, obviously, so we assign it explicitly.
const ONNX_NAMESPACE::GraphProto& graph_proto = graph_view.GetGraph().ToGraphProto();
ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, true, true);
// LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();

// Set up for subgraph import.
Expand Down Expand Up @@ -193,23 +193,9 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
model_info.error_message(), ConsumeDiagnostics());
}

// Import each node. Note that the importer uses references internally and expects nodes to be located at fixed
// memory locations for the life of iteration. So we materialize them into a fixed vector first. This is because
// the onnxruntime does not keep the serialized proto form sync'd on its own.
auto node_indices = graph_view.GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_view.GetNode(node_indices[i])->ToProto(nodes[i]);
}
for (const auto& node : nodes) {
if (torch_mlir_onnx::failed(imp.ImportNode(node))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import node '", node.name(), "': ",
model_info.error_message(), " (node:\n", node.DebugString(), "\n)", ConsumeDiagnostics());
}
}

// Finalize.
if (torch_mlir_onnx::failed(imp.FinalizeGraph())) {
imp.ImportNoneConstant();
// Import all nodes together, including the initializers.
if (torch_mlir_onnx::failed(imp.ImportAll(graph_view))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message(), ConsumeDiagnostics());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "OnnxImporter.h"

#include "core/graph/graph_viewer.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"

Expand Down Expand Up @@ -227,6 +228,7 @@ Status GraphInfo::Initialize() {
return success();
}

#if 0
const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
// Node outputs don't typically have type information, but shape inference
// will associate them in the value_info. If not there, it may be a
Expand All @@ -250,12 +252,20 @@ const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
model_info_.SetError(std::move(msg));
return nullptr;
}
#endif

// ---------------------------------------------------------------------------//
// ContextCache
// ---------------------------------------------------------------------------//

// Returns the MLIR type for a Torch "none" value by parsing its textual
// form ("!torch.none") in the cached MLIR context.
MlirType ContextCache::GetNoneType() {
  const auto type_asm = toMlirStringRef("!torch.none");
  return mlirTypeParseGet(context_, type_asm);
}

MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
if(tp.value_case() == onnx::TypeProto::VALUE_NOT_SET)
return GetNoneType();

if (tp.has_tensor_type()) {
// Convert Tensor TypeProto.
const onnx::TypeProto_Tensor &tt = tp.tensor_type();
Expand Down Expand Up @@ -503,6 +513,22 @@ NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
/*childLoc=*/{nullptr});
}

void NodeImporter::ImportNoneConstant() {
auto found_it = nv_map_.find("");
if (found_it != nv_map_.end())
return;

// Create an empty node(i.e. a none val), and place it in nv_map_.
// This function should be called _just once_, to avoid multiple
// nodes producing the same none value. Once set, there is really
// no need to put "" in nv_map again.
MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.constant.none", default_loc_, cc_.GetNoneType());
MlirValue nne = mlirOperationGetResult(op, 0);
// Place into nv_map.
nv_map_[""] = nne;
}

Status NodeImporter::DefineFunction(std::optional<std::string> name,
MlirOperation *out_function_op) {
const onnx::GraphProto &p = graph_info_.graph_proto();
Expand Down Expand Up @@ -613,15 +639,15 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
mlirStringAttrGet(context_, toMlirStringRef(m.producer_version())));
}

Status NodeImporter::ImportAll() {
Status NodeImporter::ImportAll(const onnxruntime::GraphViewer &gv) {
// TODO: Consider pulling in initializers on demand since there can be so
// much unused crap.
for (auto it : graph_info_.initializer_map()) {
if (failed(ImportInitializer(it.second)))
return failure();
}
for (auto it : graph_info_.graph_proto().node()) {
if (failed(ImportNode(it)))
if (failed(ImportNode(it, gv)))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to report what is failing for easier debugging?

Copy link
Author

@vinayakdsci vinayakdsci Aug 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The errors should propagate through the function call stack. Thus the errors are usually reported at the last point of failure. If we report errors at each function in the error path, there will be too many messages and the actual problem might get lost.
Your thoughts?

return failure();
}

Expand Down Expand Up @@ -678,17 +704,19 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
return success();
}

Status NodeImporter::ImportNode(const onnx::NodeProto &node) {
Status NodeImporter::ImportNode(const onnx::NodeProto &node,
const onnxruntime::GraphViewer &gv) {
std::string_view op_type = node.op_type();
// Handle special-form op types that do not go down the generic path.
if (op_type == "ConstantOfShape") {
return ImportConstantOfShapeNode(node);
}

return ImportGeneralNode(node);
return ImportGeneralNode(node, gv);
}

Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node,
const onnxruntime::GraphViewer &gv) {
MlirLocation loc =
node.name().empty()
? mlirLocationUnknownGet(context_)
Expand All @@ -712,7 +740,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
std::vector<MlirType> output_types;
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.FindTypeProtoForName(output_name);
gv.GetNodeArg(output_name)->TypeAsProto();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the main reason for passing the graph view? I'd like to revisit the need to use graph viewer, and whether one of the structures, such as GraphInfo should just store the graph viewer instead of passing it around as an an argument in these importer methods.

Is the problem that FindTypeProtoForName isn't finding the associated type proto for initializers?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually yes, this is the main reason. I could very well store the whole graph viewer, but I was wary of the cost it would have in terms of memory. That could very well be premature optimisation, now that you mention this.

The problem as I figure it out is that it could be possible (this is just an observation, and I might be wrong here, but my grounds for this belief are rather solid) that the GraphProto that we get from graph viewer does not reflect all the nodes that live in the graph.

It would of course have been different if we were operating on an unmodified onnx protobuf obtained directly from the model, as we have in the case of the python importer.

Do you think that passing graph viewer as an arg is bad style? I think this would warrant a change then. I will push that along with the next commit.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type info issue doesn't just happen for initializers, but a good number of other values as well. I wanted to be sure about this, so I inserted a check to inform if the name was present in initializer_map. For a significant number of such failures, it wasn't.
Of course, if I am able to find out that something else is going on, I will be revising this. This does appear to be the most complete solution at the moment, though.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me see if I understand this correctly. When an inference session is initialized for a model.onnx file (with whatever session options and EP), does ORT load the onnx model, then pass off just a graph view to the EP for further processing?

I'm guessing this is the case, since we are getting the graph from a graph view rather than a onnx::ModelProto.

If a graph view is our starting point for ORT, would it make more sense to just store the graph view and not the graph in the GraphInfo structure? Does graph view have all the same capabilities of a graph (e.g., the ability to read nodes and value info)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does ORT load the onnx model, then pass off just a graph view to the EP for further processing?

Yes, that is correct. This is the essence of the matter.

Does graph view have all the same capabilities of a graph (e.g., the ability to read nodes and value info)?

It actually does, and more. The graph that we operate on currently is obtained from graph viewer, and thus we can think of graph viewer as a super view of the actual onnx graph.

If a graph view is our starting point for ORT, would it make more sense to just store the graph view and not the graph in the GraphInfo structure?

This, unfortunately, requires some thought. All the processing that we do outside of internal classes like GraphInfo that works on graph viewer would now require handling inside these classes and structs.
Also, I am doubtful about the ergonomics of storing the whole graph viewer, especially unless we heavily use it. Extracting values from it is something we'll still have to do.

In conclusion, storing graph viewer is actually a very viable option, though it will have to be unwrapped somewhere later down the execution path.

if (!type_proto)
return failure();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// for class members/accessors because canonical protobuf coding presumes
// this kind of style.

#include "core/graph/graph_viewer.h"
#include "mlir-c/IR.h"
#include "onnx/onnx_pb.h"

Expand Down Expand Up @@ -182,6 +183,7 @@ class ContextCache {
/// the null type.
MlirType GetVtensorType(const std::vector<int64_t> &dims,
MlirType element_type);
MlirType GetNoneType();

private:
ModelInfo &model_info_;
Expand All @@ -206,10 +208,13 @@ class NodeImporter {
MlirOperation *out_function_op = nullptr);

/// Imports all nodes topologically. Internally calls FinalizeGraph.
Status ImportAll();
Status ImportAll(const onnxruntime::GraphViewer &gv);

/// Imports !torch.none constant values that replace inputs that do not have names.
void ImportNoneConstant();

/// Import nodes one at a time. Must complete with a call to FinalizeGraph.
Status ImportNode(const onnx::NodeProto &node);
Status ImportNode(const onnx::NodeProto &node, const onnxruntime::GraphViewer &gv);
Status FinalizeGraph();

void DebugDumpModule();
Expand All @@ -220,7 +225,7 @@ class NodeImporter {
MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);

// Special-form nodes.
Status ImportGeneralNode(const onnx::NodeProto &node);
Status ImportGeneralNode(const onnx::NodeProto &node, const onnxruntime::GraphViewer &gv);
Status ImportConstantOfShapeNode(const onnx::NodeProto &node);

/// Looks for an initializer for `name` and attempts to treat it as a 1D
Expand Down