Skip to content

Commit

Permalink
Linting for dilations pt 1
Browse files — browse the repository at this point in the history
  • Loading branch information
jezsadler committed Nov 17, 2023
1 parent 6d0cf77 commit 668f192
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 26 deletions.
8 changes: 4 additions & 4 deletions src/omlt/io/onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ def _consume_conv_nodes(self, node, next_nodes):
if "pads" in attr:
pads = attr["pads"]
else:
pads = 2*(len(input_output_size)-1)*[0]
pads = 2 * (len(input_output_size) - 1) * [0]

if "dilations" in attr:
dilations = attr["dilations"]
else:
dilations = (len(input_output_size)-1)*[1]
dilations = (len(input_output_size) - 1) * [1]

# Other attributes are not supported
if attr["group"] != 1:
Expand All @@ -379,8 +379,8 @@ def _consume_conv_nodes(self, node, next_nodes):

# generate new nodes for the node output
padding = [
pads[i] + pads[i + len(input_output_size)-1]
for i in range(len(input_output_size)-1)
pads[i] + pads[i + len(input_output_size) - 1]
for i in range(len(input_output_size) - 1)
]
output_size = [out_channels]
for w, k, s, p in zip(input_output_size[1:], kernel_shape, strides, padding):
Expand Down
88 changes: 66 additions & 22 deletions src/omlt/neuralnet/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def dilations(self):
def dilated_kernel_shape(self):
    """Return the shape of the kernel after dilation is applied.

    Each kernel dimension of size k expands to d * (k - 1) + 1, where d is
    the dilation factor for that axis; a dilation of 1 leaves the size
    unchanged.

    Returns
    -------
    tuple of int
        The per-axis sizes of the dilated kernel.
    """
    # NOTE(review): assumes self.dilations has (at least) one entry per
    # kernel axis — confirm upstream validation guarantees this.
    dilated_dims = [
        d * (k - 1) + 1 for d, k in zip(self.dilations, self.kernel_shape)
    ]
    return tuple(dilated_dims)
Expand Down Expand Up @@ -333,8 +333,7 @@ def kernel_index_with_input_indexes(self, out_d, out_r, out_c):
# as this could require using a partial kernel
# even though we loop over ALL kernel indexes.
if not all(
input_index[i] < self.input_size[i]
and input_index[i] >= 0
input_index[i] < self.input_size[i] and input_index[i] >= 0
for i in range(len(input_index))
):
continue
Expand Down Expand Up @@ -498,25 +497,70 @@ def __init__(
)
self.__kernel = kernel
if self.dilations != [1, 1]:
dilate_rows = np.hstack([
np.hstack([
np.hstack([
kernel[:, :, i, :].reshape((
kernel.shape[0], kernel.shape[1], 1, kernel.shape[3])),
np.zeros((
kernel.shape[0], kernel.shape[1], self.dilations[0] - 1, kernel.shape[3]))])
for i in range(kernel.shape[2]-1)]),
kernel[:, :, -1, :].reshape((kernel.shape[0], kernel.shape[1], 1, kernel.shape[3]))
])
dilate_kernel = np.dstack([
np.dstack([
np.dstack([
dilate_rows[:, :, :, i].reshape((
dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1)),
np.zeros((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], self.dilations[1] - 1))])
for i in range(dilate_rows.shape[3]-1)]),
dilate_rows[:, :, :, -1].reshape((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1))
])
dilate_rows = np.hstack(
[
np.hstack(
[
np.hstack(
[
kernel[:, :, i, :].reshape(
(
kernel.shape[0],
kernel.shape[1],
1,
kernel.shape[3]
)
),
np.zeros(
(
kernel.shape[0],
kernel.shape[1],
self.dilations[0] - 1,
kernel.shape[3]
)
)
]
)
for i in range(kernel.shape[2] - 1)
]
),
kernel[:, :, -1, :].reshape(
(kernel.shape[0], kernel.shape[1], 1, kernel.shape[3])
),
]
)
dilate_kernel = np.dstack(
[
np.dstack(
[
np.dstack(
[
dilate_rows[:, :, :, i].reshape(
(
dilate_rows.shape[0],
dilate_rows.shape[1],
dilate_rows.shape[2],
1
)
),
np.zeros(
(
dilate_rows.shape[0],
dilate_rows.shape[1],
dilate_rows.shape[2],
self.dilations[1] - 1
)
)
]
)
for i in range(dilate_rows.shape[3]-1)
]
),
dilate_rows[:, :, :, -1].reshape(
(dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1)
),
]
)
self.__dilated_kernel = dilate_kernel
else:
self.__dilated_kernel = kernel
Expand Down

0 comments on commit 668f192

Please sign in to comment.