Skip to content

Commit

Permalink
update pool tests
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Jan 22, 2025
1 parent 505a942 commit f97ca29
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 17 deletions.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ByJ����?��?aK@��&@���?�q�?̌s?;�)>wU�>�5�>5�?;��?8C?���>���>�?3h�?B�>8.^?�g%@�L$@&sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?V]�>���?/�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>Q��>�i�>(?��+?�4?I�4?��U?�h�?~t�>���?�>�?rr?mHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
@�^@��?�E�?6`�?tp�?|R�?��?E�z?2~?,O<?
ByJ����?��?aK@��&@���?�q�?͌s?;�)>wU�>�5�>5�?;��?8C?���>���>�?3h�?B�>7.^?�g%@�L$@%sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?V]�>���?/�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>P��>�i�>(?��+?�4?I�4?��U?�h�?~t�>���?�>�?rr?nHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
@�^@��?�E�?5`�?tp�?|R�?��?E�z?2~?,O<?
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
37 changes: 22 additions & 15 deletions onnx/reference/ops/op_pool_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,29 @@ def pool(
],
):
window = padded[shape[0], shape[1]]
elements = []
for i in range(spatial_size):
# NOTE: The if condition is to avoid the case where the window is out of bound
# we need to avoid the pixels that are out of bound being included in the window
elements.extend(
num
for num in range(
strides[i] * shape[i + 2],
strides[i] * shape[i + 2] + (1 + (kernel[i] - 1) * dilations[i]),
dilations[i],
window_vals = np.array(
[
window[i]
for i in list(
itertools.product(
*[
[
pixel
for pixel in range(
strides[i] * shape[i + 2],
strides[i] * shape[i + 2]
+ (1 + (kernel[i] - 1) * dilations[i]),
dilations[i],
)
if pixel
< x_shape[i + 2] + pads[i] + pads[spatial_size + i]
]
for i in range(spatial_size)
]
)
)
if num < x_shape[i + 2] + pads[i] + pads[i + spatial_size]
)
window_vals = np.array(
[window[indices] for indices in itertools.product(elements)]
)
]
)

if pooling_type == "AVG":
f = np.average
Expand Down

0 comments on commit f97ca29

Please sign in to comment.