-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranspose.py
29 lines (22 loc) · 858 Bytes
/
transpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
from tensorflow.keras import backend as KerasBackend # noqa
def transpose(a, perm=None, name: str | None = None):
    """Transpose a dense or sparse tensor.

    Dispatches to ``tf.sparse.transpose`` when the Keras backend reports ``a``
    as sparse, and to ``tf.transpose`` otherwise.

    Args:
        a: Tensor or sparse tensor with rank k.
        perm: Permutation of the k dimensions. When ``None`` it is forwarded
            unchanged, so TensorFlow applies its documented default of
            reversing the dimension order -- ``(1, 0)`` for a matrix.
        name: Operation name.

    Returns:
        Tensor or sparse tensor with rank k.
    """
    # Pick the transpose op that matches the sparsity of the input.
    if KerasBackend.is_sparse(a):
        transpose_op = tf.sparse.transpose
    else:
        transpose_op = tf.transpose

    # perm=None is passed through on purpose: a hard-coded (1, 0) fallback
    # only fits rank-2 inputs and fails for any other rank, whereas the
    # TensorFlow default yields the same result for matrices and also
    # handles rank-k inputs.
    return transpose_op(a, perm=perm, name=name)