forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph_transformer.h
70 lines (56 loc) · 2.54 KB
/
graph_transformer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <unordered_set>
#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/graph_transformer_level.h"
namespace onnxruntime {
/**
@class GraphTransformer
The interface for in-place transformation of a Graph.

Derived classes implement the pure virtual ApplyImpl. Implementations must call Recurse on nodes so that
subgraphs of control flow nodes (Scan/If/Loop) are transformed as well.
*/
class GraphTransformer {
 public:
  /** Constructs a transformer.
  @param name Name of this transformer; returned by Name().
  @param compatible_execution_providers Set of execution provider type names. It is stored and returned
         unchanged by GetCompatibleExecutionProviders(). Defaults to an empty set.
  */
  explicit GraphTransformer(const std::string& name,
                            const std::unordered_set<std::string>& compatible_execution_providers = {})
      : name_(name), compatible_provider_types_(compatible_execution_providers) {
  }

  virtual ~GraphTransformer() = default;

  /** Gets the name of this graph transformer. */
  const std::string& Name() const noexcept {
    return name_;
  }

  /** Gets the execution provider types passed at construction.
  NOTE(review): what an empty set means (e.g. "no restriction") is decided by the code that consults this
  value, which is not visible here. Confirm before relying on it.
  */
  const std::unordered_set<std::string>& GetCompatibleExecutionProviders() const noexcept {
    return compatible_provider_types_;
  }

  /** Apply the in-place transformation defined by this transformer to the provided Graph instance.
  @param graph The main graph to transform in place.
  @param[out] modified Set to true if the Graph was modified.
  @param logger Logger to use while applying the transformation.
  @returns Status with success or error information.
  */
  common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;

 protected:
  /** Helper method to call ApplyImpl on every subgraph held in the attributes of a Node.
  Each subgraph is processed at level graph_level + 1. Processing stops at the first subgraph whose ApplyImpl
  call fails, and that error status is returned.
  @param node Node whose attribute subgraphs are transformed.
  @param[in,out] modified Passed through to ApplyImpl for each subgraph; Recurse itself never resets it.
  @param graph_level Level of the graph that contains `node` (0 for the main graph).
  @param logger Logger passed through to ApplyImpl.
  @returns Status::OK() if ApplyImpl succeeded for every subgraph, otherwise the first error.
  */
  common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
    // Subgraphs sit one level below the graph that owns `node`.
    const int subgraph_level = graph_level + 1;
    for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
      auto& subgraph = *entry.second;
      ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger));
    }
    return Status::OK();
  }

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);

  // Apply the transform to the graph.
  // graph_level is 0 for the main graph, and is incremented when descending into the subgraph of a node.
  // You MUST call Recurse for all valid Nodes in the graph to ensure any subgraphs in control flow nodes
  // (Scan/If/Loop) are processed as well.
  // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
  // the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl for the main graph
  // completes (if 'modified' is true) should suffice.
  virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                   const logging::Logger& logger) const = 0;

  const std::string name_;
  const std::unordered_set<std::string> compatible_provider_types_;
};
} // namespace onnxruntime