"""Defines TensorShape class."""
import numpy as np
class TensorShape(object):
"""`TensorShape` is the wrapper of a tuple of positive integers (or Nones)
that specify the shape of a symbolic Tensor at graph construction time.
Integers indicate concrete size of a dimension, while None means unknown size.
The tuple can also be None, which means the rank (number of dimensions) of the
tensor is not known.
The "level" of a `TensorShape` indicates how much we know about the shape of
the `Tensor`. See method `level`.
"""

  def __init__(self, raw_shape):
    """Constructor.

    Args:
      raw_shape (List[int]): a list (or tuple) of positive integers or Nones,
        the shape of a `Tensor`; or None if the rank is unknown.
    """
    # rank is not known
    if raw_shape is None:
      self._raw_shape = None
      self._ndims = None
    # rank is known
    else:
      self._raw_shape = tuple(raw_shape)
      self._ndims = len(raw_shape)

  @property
  def ndims(self):
    """Number of dimensions (rank), or None if the rank is unknown."""
    return self._ndims

  @property
  def raw_shape(self):
    """The raw shape as a tuple of ints (or Nones), or None if rank is unknown."""
    return self._raw_shape

  @property
  def level(self):
    """Level of the TensorShape:

    0: unknown rank (num of dimensions)
    1: rank is known (but size of some dimension is unknown)
    2: full shape is known
    """
    if self.ndims is None:
      return 0
    elif all([s is not None for s in self.raw_shape]):
      return 2
    else:
      return 1
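
  # Illustrative examples (not part of the original file), following the
  # constructor and `level` semantics above:
  #   TensorShape(None).level            -> 0  (rank unknown)
  #   TensorShape([None, 28, 28]).level  -> 1  (rank known, one size unknown)
  #   TensorShape([32, 28, 28]).level    -> 2  (fully specified)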

  def _partial_size(self):
    """Compute the partial size.

    The partial size is the product of the sizes of all dimensions whose size
    is known (i.e. the number of elements contributed by the known dimensions).
    Returns -1 if the rank is unknown or no dimension size is known.
    """
    if self.level == 0:
      return -1
    else:
      sizes = [s for s in self.raw_shape if s is not None]
      if sizes:
        return np.prod(sizes).astype("int32").item()
      else:
        return -1
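
  # For example (illustrative, not part of the original file):
  #   TensorShape([None, 28, 28])._partial_size()  -> 784  (28 * 28)
  #   TensorShape([None, None])._partial_size()    -> -1   (no known sizes)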

  def __repr__(self):
    if self._raw_shape is None:
      return "TensorShape(None)"
    else:
      return "TensorShape([%s])" % ", ".join(map(str, self._raw_shape))

  def _compatible_with(self, tensor_shape):
    """Checks if this `TensorShape` is compatible with another shape.

    Two shapes are compatible if the ranks match (or either rank is unknown)
    and every pair of corresponding sizes is equal or contains an unknown size.
    Example: `TensorShape([None, 1, 2, 3])` is compatible with
    `TensorShape([1, 1, 2, 3])`, but not with `TensorShape([1, 2, 2, 3])`.

    Args:
      tensor_shape (TensorShape): the other `TensorShape`.
    """
    if self.ndims is None or tensor_shape.ndims is None:
      return True
    if self.ndims != tensor_shape.ndims:
      return False
    return all([
        d1 is None or d2 is None or d1 == d2
        for d1, d2 in zip(self.raw_shape, tensor_shape.raw_shape)
    ])
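
  # Illustrative checks (not part of the original file):
  #   TensorShape([None, 3])._compatible_with(TensorShape([5, 3]))  -> True
  #   TensorShape([4, 3])._compatible_with(TensorShape([5, 3]))     -> False
  #   TensorShape(None)._compatible_with(TensorShape([5, 3]))       -> True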

  def _broadcastable_with(self, tensor_shape):
    """Returns whether this `TensorShape` is broadcastable with `tensor_shape`.

    Two shapes are broadcastable if, for every pair of corresponding sizes
    starting from the lowest (trailing) dimension, one of the following holds:
      * they are equal
      * one of them is 1
      * one of them is None (unknown)
    """
    if (
        self.ndims is None or tensor_shape.ndims is None or self.ndims == 0 or
        tensor_shape.ndims == 0
    ):
      return True
    return all([
        d1 is None or d2 is None or d1 == 1 or d2 == 1 or d1 == d2
        for d1, d2 in zip(self.raw_shape[::-1], tensor_shape.raw_shape[::-1])
    ])
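
  # Illustrative broadcasting checks (not part of the original file):
  #   TensorShape([8, 1, 6])._broadcastable_with(TensorShape([7, 6]))  -> True
  #   TensorShape([8, 4, 6])._broadcastable_with(TensorShape([7, 6]))  -> False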

  def _merge(self, tensor_shape, skip=()):
    """Replace unknown dimension sizes in this `TensorShape` with the concrete
    sizes from `tensor_shape`.

    Args:
      tensor_shape (TensorShape): the other `TensorShape`.
      skip (List[int]): list of dimension indices to skip.

    Raises:
      ValueError: if the two shapes are not compatible.
    """
    if self._compatible_with(tensor_shape):
      if self.ndims is not None and tensor_shape.ndims is not None:
        raw_shape = list(self._raw_shape)
        for i, s in enumerate(raw_shape):
          if i in skip:
            continue
          if s is None and tensor_shape.raw_shape[i] is not None:
            raw_shape[i] = tensor_shape.raw_shape[i]
        self._raw_shape = tuple(raw_shape)
    else:
      raise ValueError(
          f"Attempting to merge incompatible shapes: {self}, {tensor_shape}"
      )
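
  # Illustrative merge (not part of the original file):
  #   s = TensorShape([None, 28, None])
  #   s._merge(TensorShape([32, 28, 3]))  # s becomes TensorShape([32, 28, 3])
  #   s._merge(TensorShape([64, 28, 3]))  # raises ValueError (32 vs. 64)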

  def _diff_at(self, tensor_shape):
    """Returns the indices of dimensions at which this `TensorShape` differs
    from `tensor_shape`.
    """
    axes = []
    assert self._compatible_with(tensor_shape)
    if self.ndims is not None and tensor_shape.ndims is not None:
      for i, (d1, d2) in enumerate(zip(self.raw_shape, tensor_shape.raw_shape)):
        if d1 is not None and d2 is not None and d1 != d2:
          axes.append(i)
    return axes

  def __getitem__(self, k):
    """Allows for indexing and slicing. Example:

    `TensorShape([1, 2, None])[1] == 2`, and
    `TensorShape([1, 2, None])[1:]` returns `TensorShape([2, None])`.
    """
    assert self.ndims is not None
    if isinstance(k, slice):
      return TensorShape(self.raw_shape[k])
    else:
      return self._raw_shape[k]
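

# A minimal, runnable usage sketch. This block is not part of the original
# module; the expected values below follow from the semantics documented in
# the methods above.
if __name__ == "__main__":
  batch = TensorShape([None, 28, 28, 3])  # rank known, batch size unknown
  fixed = TensorShape([32, 28, 28, 3])    # fully specified

  assert TensorShape(None).level == 0
  assert batch.level == 1 and fixed.level == 2

  # Known dimensions contribute 28 * 28 * 3 = 2352 elements.
  assert batch._partial_size() == 28 * 28 * 3

  # Compatibility: None matches any size; concrete sizes must agree.
  assert batch._compatible_with(fixed)
  assert not fixed._compatible_with(TensorShape([32, 28, 28, 1]))

  # Broadcasting is checked from the trailing dimension.
  assert TensorShape([8, 1, 6])._broadcastable_with(TensorShape([7, 6]))

  # Merging fills unknown sizes in place from the other shape.
  batch._merge(fixed)
  assert batch.raw_shape == (32, 28, 28, 3)

  # Indexing returns an int (or None); slicing returns a new TensorShape.
  assert fixed[0] == 32
  assert fixed[1:].raw_shape == (28, 28, 3)

  print("All checks passed:", batch, fixed)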