scatter_sum.py
import tensorflow as tf

from gns.scatter.mixed_mode_support import mixed_mode_support


@mixed_mode_support
def scatter_sum(messages, indices, n_nodes):
    """
    Sums messages over the segments defined by `indices`. Supports messages in
    single/disjoint mode (rank 2) and in mixed mode (rank 3).
    The output has the same rank as the input, with the "nodes" dimension set
    to `n_nodes`.

    In single/disjoint mode, messages are expected to have shape
    `[n_messages, n_features]` and the output has shape
    `[n_nodes, n_features]`.

    In mixed mode, messages are expected to have shape
    `[batch, n_messages, n_features]` and the output has shape
    `[batch, n_nodes, n_features]`.

    `indices` is always expected to be a 1-D tensor of integers smaller than
    `n_nodes`, with shape `[n_messages]`.

    For any missing index (i.e. any integer `0 <= i < n_nodes` that does not
    appear in `indices`), the corresponding output row is zero.
    Messages whose index is negative are ignored during aggregation.

    Args:
        messages: 2-D or 3-D tensor of messages.
        indices: 1-D integer tensor giving, for each message, the index of the
            node it is summed into.
        n_nodes: size of the output along the nodes dimension.

    Returns:
        Tensor with the same rank as `messages`.
    """
    return tf.math.unsorted_segment_sum(messages, indices, n_nodes)
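To make the documented semantics concrete without needing TensorFlow installed, here is a pure-Python reference sketch. `reference_scatter_sum` and `reference_scatter_sum_mixed` are hypothetical helpers, not part of `gns`, and are not how `mixed_mode_support` is actually implemented. They assume the behaviour described in the docstring above: nodes that receive no message get zeros, messages with a negative index are dropped, indices must be below `n_nodes`, and in mixed mode the same 1-D `indices` is applied to every batch element.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def reference_scatter_sum(
    messages: Sequence[Sequence[float]],
    indices: Sequence[int],
    n_nodes: int,
) -> Matrix:
    """Hypothetical single/disjoint-mode sketch.

    messages: [n_messages, n_features]; indices: [n_messages];
    returns [n_nodes, n_features].
    """
    if len(messages) != len(indices):
        raise ValueError("need exactly one index per message")
    n_features = len(messages[0]) if messages else 0
    # Nodes that receive no message keep the additive identity, 0.
    out = [[0.0] * n_features for _ in range(n_nodes)]
    for message, node in zip(messages, indices):
        if node < 0:
            # Negative indices are dropped from the aggregation.
            continue
        if node >= n_nodes:
            # Assumption: out-of-range indices are treated as an error.
            raise IndexError(f"index {node} out of range for {n_nodes} nodes")
        for f, value in enumerate(message):
            out[node][f] += value
    return out


def reference_scatter_sum_mixed(
    messages: Sequence[Sequence[Sequence[float]]],
    indices: Sequence[int],
    n_nodes: int,
) -> List[Matrix]:
    """Hypothetical mixed-mode sketch.

    messages: [batch, n_messages, n_features]; the same indices are used
    for every batch element; returns [batch, n_nodes, n_features].
    """
    return [reference_scatter_sum(sample, indices, n_nodes) for sample in messages]


if __name__ == "__main__":
    msgs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    idx = [0, 2, 0, -1]  # node 1 gets no message; the last message is dropped
    print(reference_scatter_sum(msgs, idx, n_nodes=3))
    # [[6.0, 8.0], [0.0, 0.0], [3.0, 4.0]]
    print(len(reference_scatter_sum_mixed([msgs, msgs], idx, n_nodes=3)))
    # 2
```

Node 0 receives messages 0 and 2 (`[1, 2] + [5, 6]`), node 2 receives message 1, node 1 receives nothing and stays zero, and the message with index `-1` is discarded. Looping over the batch in the mixed-mode helper is chosen for readability; a tensor implementation would instead vectorize over the batch dimension.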