[IREE EP][JIT, OnnxImporter] Synchronize C++ and Python importers. #7
base: iree_ep
@@ -9,6 +9,7 @@
 #include "OnnxImporter.h"

+#include "core/graph/graph_viewer.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
@@ -227,6 +228,7 @@ Status GraphInfo::Initialize() {
   return success();
 }

+#if 0
 const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
   // Node outputs don't typically have type information, but shape inference
   // will associate them in the value_info. If not there, it may be a
@@ -250,12 +252,20 @@ const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
   model_info_.SetError(std::move(msg));
   return nullptr;
 }
+#endif

 // ---------------------------------------------------------------------------//
 // ContextCache
 // ---------------------------------------------------------------------------//

+MlirType ContextCache::GetNoneType() {
+  return mlirTypeParseGet(context_, toMlirStringRef("!torch.none"));
+}
+
 MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
+  if (tp.value_case() == onnx::TypeProto::VALUE_NOT_SET)
+    return GetNoneType();
+
   if (tp.has_tensor_type()) {
     // Convert Tensor TypeProto.
     const onnx::TypeProto_Tensor &tt = tp.tensor_type();
@@ -503,6 +513,22 @@ NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
                            /*childLoc=*/{nullptr});
 }

+void NodeImporter::ImportNoneConstant() {
+  auto found_it = nv_map_.find("");
+  if (found_it != nv_map_.end())
+    return;
+
+  // Create an empty node (i.e. a none value) and place it in nv_map_.
+  // This function should be called _just once_, to avoid multiple
+  // nodes producing the same none value. Once set, there is really
+  // no need to put "" in nv_map_ again.
+  MlirOperation op = createMlirOperationAtEnd(
+      body_block_, "torch.constant.none", default_loc_, cc_.GetNoneType());
+  MlirValue nne = mlirOperationGetResult(op, 0);
+  // Place into nv_map_.
+  nv_map_[""] = nne;
+}
+
 Status NodeImporter::DefineFunction(std::optional<std::string> name,
                                     MlirOperation *out_function_op) {
   const onnx::GraphProto &p = graph_info_.graph_proto();
@@ -613,15 +639,15 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
       mlirStringAttrGet(context_, toMlirStringRef(m.producer_version())));
 }

-Status NodeImporter::ImportAll() {
+Status NodeImporter::ImportAll(const onnxruntime::GraphViewer &gv) {
   // TODO: Consider pulling in initializers on demand since there can be so
   // much unused crap.
   for (auto it : graph_info_.initializer_map()) {
     if (failed(ImportInitializer(it.second)))
       return failure();
   }
   for (auto it : graph_info_.graph_proto().node()) {
-    if (failed(ImportNode(it)))
+    if (failed(ImportNode(it, gv)))
       return failure();
   }
@@ -678,17 +704,19 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
   return success();
 }

-Status NodeImporter::ImportNode(const onnx::NodeProto &node) {
+Status NodeImporter::ImportNode(const onnx::NodeProto &node,
+                                const onnxruntime::GraphViewer &gv) {
   std::string_view op_type = node.op_type();
   // Handle special-form op types that do not go down the generic path.
   if (op_type == "ConstantOfShape") {
     return ImportConstantOfShapeNode(node);
   }

-  return ImportGeneralNode(node);
+  return ImportGeneralNode(node, gv);
 }

-Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
+Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node,
+                                       const onnxruntime::GraphViewer &gv) {
   MlirLocation loc =
       node.name().empty()
           ? mlirLocationUnknownGet(context_)
@@ -712,7 +740,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   std::vector<MlirType> output_types;
   for (auto &output_name : node.output()) {
     const onnx::TypeProto *type_proto =
-        graph_info_.FindTypeProtoForName(output_name);
+        gv.GetNodeArg(output_name)->TypeAsProto();
     if (!type_proto)
       return failure();

Reviewer: Is this the main reason for passing the graph viewer? I'd like to revisit the need to use the graph viewer, and whether one of the existing structures, such as …, would suffice. Is the problem that …?

Author: Actually yes, this is the main reason. I could very well store the whole graph viewer, but I was wary of the memory cost. That could well be premature optimisation, now that you mention it. The problem, as I understand it, is that the GraphProto we get from the graph viewer may not reflect all the nodes that live in the graph. It would of course have been different if we were operating on an unmodified ONNX protobuf obtained directly from the model, as in the Python importer. Do you think that passing the graph viewer as an argument is bad style? If so, that warrants a change, and I will push it with the next commit.

Author: This type-info issue doesn't just happen for initializers, but for a good number of other values as well. I wanted to be sure about this, so I inserted a check to report when the …

Reviewer: Let me see if I understand this correctly. When an inference session is initialized for a …, I'm guessing this is the case, since we are getting the graph from a graph viewer rather than a …. If a graph viewer is our starting point for ORT, would it make more sense to just store the graph viewer, and not the graph, in the GraphInfo structure? Does the graph viewer have all the same capabilities as a graph (e.g., the ability to read nodes and value info)?

Author: Yes, that is correct; this is the essence of the matter. It actually does, and more: the graph we currently operate on is obtained from the graph viewer, so we can think of the graph viewer as a superset view of the actual ONNX graph. That change, unfortunately, requires some thought: all the processing that currently uses the graph viewer outside of internal classes like GraphInfo would then have to be handled inside those classes and structs. In conclusion, storing the graph viewer is a very viable option, though it will have to be unwrapped somewhere later down the execution path.
Reviewer: Is there a way to report what is failing, for easier debugging?

Author: The errors propagate up through the function call stack, so they are usually reported at the last point of failure. If we reported errors in every function along the error path, there would be too many messages and the actual problem might get lost. Your thoughts?