forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubgraph_matcher.h
53 lines (47 loc) · 1.94 KB
/
subgraph_matcher.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
#include <torch/csrc/jit/ir.h>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
/**
* \brief A structure describing a match of a pattern in a graph.
*
* The structure contains an anchor node, from which the match was found, and
* match-maps for nodes and values. A match-map specifies correspondance between
* nodes in the pattern graph (match-map keys) with nodes in the actual graph
* (match-map values). We keep such maps for both nodes and values.
*/
struct Match {
Node* anchor;
std::unordered_map<const Node*, Node*> nodes_map;
std::unordered_map<const Value*, Value*> values_map;
};
/**
* \brief Find all matches of a \p PATTERN in a \p GRAPH.
*
* The function returns a vector of match-descriptors (see description of
* `struct Match`).
*
* Matching rules:
* - Pattern graph must contain a single block.
* - Matched subgraphs do not span across different blocks.
* - No uses outside the match are allowed, except for Param and Return nodes.
* Basically, we're matching hammocks, not arbitrary subgraphs.
* - Pattern graph must return only one value (i.e. it must have a single
* node leading to return).
* - Nodes that are not used in computation of the return value in the pattern
* graph are ignored during matching (IOW, we're essentially performing DCE on
* the pattern).
* - Pattern graph nodes cannot alias. TODO: the check not implemented yet.
* - Aliasing nodes in the graph can not consitute a match (i.e. in all found
* matches no nodes in the subgraph alias with each other). TODO: the check not
* implemented yet.
* - The matcher will not mutate either the pattern graph or the matched graph,
* but the latter is taken as non-const so that Match may contain non-const
* pointers. This enables clients of this API to use Match to drive mutations.
*/
std::vector<Match> TORCH_API
findPatternMatches(const Graph& pattern, Graph& graph);
} // namespace jit
} // namespace torch