Skip to content

Commit

Permalink
Adding mapping for mean in the torch to onnx op naming dict in onnx_u…
Browse files Browse the repository at this point in the history
…tils. (#2624)

Signed-off-by: Sayanta Mukherjee <[email protected]>
  • Loading branch information
quic-ssayanta authored Dec 28, 2023
1 parent ab0ed37 commit 48e8781
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from keras.engine.functional import Functional # pylint: disable=import-error
from keras.engine.keras_tensor import KerasTensor # pylint: disable=import-error
from keras.layers.core.tf_op_layer import TFOpLambda # pylint: disable=import-error
from keras.layers.merging.base_merge import _Merge as MergeLayersParentClass # pylint: disable=ungrouped-imports
from keras.layers.merging.base_merge import _Merge as MergeLayersParentClass # pylint: disable=ungrouped-imports, import-error
else:
# Ignore pylint errors due to conditional imports
from tensorflow.python.keras.engine.base_layer_utils import is_subclassed # pylint: disable=ungrouped-imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@
elementwise_ops.Subtract: ['Sub'],
elementwise_ops.Tile: ['Tile'],
elementwise_ops.TopK: ['TopK'],
torchvision.ops.RoIPool: ['MaxRoiPool']
torchvision.ops.RoIPool: ['MaxRoiPool'],
elementwise_ops.Mean: ['ReduceMean']
}

# Maps pytorch functional op string names to corresponding onnx types.
Expand Down

0 comments on commit 48e8781

Please sign in to comment.