From a494a81b4f08e13d6fb3457d7e527585a66357fb Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Wed, 17 Jul 2019 12:20:16 +0200 Subject: [PATCH] Correct computing offset of anchors. Check https://github.com/fizyr/keras-retinanet/issues/1073 for more information. --- keras_retinanet/backend/common.py | 18 ++++++--- keras_retinanet/layers/_misc.py | 19 ++++----- keras_retinanet/models/retinanet.py | 7 ++-- keras_retinanet/utils/anchors.py | 21 ++++++---- tests/backend/test_common.py | 61 ++++++++++++----------------- tests/layers/test_misc.py | 12 +++++- tests/utils/test_anchors.py | 26 ++++++++++++ 7 files changed, 99 insertions(+), 65 deletions(-) diff --git a/keras_retinanet/backend/common.py b/keras_retinanet/backend/common.py index 8f8dcc6e0..115528e9e 100644 --- a/keras_retinanet/backend/common.py +++ b/keras_retinanet/backend/common.py @@ -52,16 +52,22 @@ def bbox_transform_inv(boxes, deltas, mean=None, std=None): return pred_boxes -def shift(shape, stride, anchors): +def shift(image_shape, features_shape, stride, anchors): """ Produce shifted anchors based on shape of the map and stride size. Args - shape : Shape to shift the anchors over. - stride : Stride to shift the anchors with over the shape. - anchors: The anchors to apply at each location. + image_shape : Shape of the input image. + features_shape : Shape of the feature map. + stride : Stride to shift the anchors with over the image. + anchors : The anchors to apply at each location. """ - shift_x = (keras.backend.arange(0, shape[1], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride - shift_y = (keras.backend.arange(0, shape[0], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride + # compute the offset of the anchors based on the image shape and the feature map shape + # see https://github.com/fizyr/keras-retinanet/issues/1073 for more information + offset_x = keras.backend.cast((image_shape[1] - (features_shape[1] - 1) * stride), keras.backend.floatx()) / 2.0 + offset_y = keras.backend.cast((image_shape[0] - (features_shape[0] - 1) * stride), keras.backend.floatx()) / 2.0 + + shift_x = keras.backend.arange(0, features_shape[1], dtype=keras.backend.floatx()) * stride + offset_x + shift_y = keras.backend.arange(0, features_shape[0], dtype=keras.backend.floatx()) * stride + offset_y shift_x, shift_y = meshgrid(shift_x, shift_y) shift_x = keras.backend.reshape(shift_x, [-1]) diff --git a/keras_retinanet/layers/_misc.py b/keras_retinanet/layers/_misc.py index 72db2571c..700f2b7b2 100644 --- a/keras_retinanet/layers/_misc.py +++ b/keras_retinanet/layers/_misc.py @@ -58,28 +58,29 @@ def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs): super(Anchors, self).__init__(*args, **kwargs) def call(self, inputs, **kwargs): - features = inputs - features_shape = keras.backend.shape(features) + image, features = inputs + features_shape = keras.backend.shape(features) + image_shape = keras.backend.shape(image) # generate proposals from bbox deltas and shifted anchors if keras.backend.image_data_format() == 'channels_first': - anchors = backend.shift(features_shape[2:4], self.stride, self.anchors) + anchors = backend.shift(image_shape[2:4], features_shape[2:4], self.stride, self.anchors) else: - anchors = backend.shift(features_shape[1:3], self.stride, self.anchors) + anchors = backend.shift(image_shape[1:3], features_shape[1:3], self.stride, self.anchors) anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1)) return anchors def compute_output_shape(self, input_shape): - if None not in input_shape[1:]: + if None not in input_shape[1][1:]: if keras.backend.image_data_format() == 'channels_first': - total = np.prod(input_shape[2:4]) * self.num_anchors + total = np.prod(input_shape[1][2:4]) * self.num_anchors else: - total = np.prod(input_shape[1:3]) * self.num_anchors + total = np.prod(input_shape[1][1:3]) * self.num_anchors - return (input_shape[0], total, 4) + return (input_shape[1][0], total, 4) else: - return (input_shape[0], None, 4) + return (input_shape[1][0], None, 4) def get_config(self): config = super(Anchors, self).get_config() diff --git a/keras_retinanet/models/retinanet.py b/keras_retinanet/models/retinanet.py index ba75d4505..554a46183 100644 --- a/keras_retinanet/models/retinanet.py +++ b/keras_retinanet/models/retinanet.py @@ -207,11 +207,12 @@ def __build_pyramid(models, features): return [__build_model_pyramid(n, m, features) for n, m in models] -def __build_anchors(anchor_parameters, features): +def __build_anchors(anchor_parameters, image, features): """ Builds anchors for the shape of the features from FPN. Args anchor_parameters : Parameteres that determine how anchors are generated. + image : The image input tensor. features : The FPN features. Returns @@ -229,7 +230,7 @@ def __build_anchors(anchor_parameters, features): ratios=anchor_parameters.ratios, scales=anchor_parameters.scales, name='anchors_{}'.format(i) - )(f) for i, f in enumerate(features) + )([image, f]) for i, f in enumerate(features) ] return keras.layers.Concatenate(axis=1, name='anchors')(anchors) @@ -328,7 +329,7 @@ def retinanet_bbox( # compute the anchors features = [model.get_layer(p_name).output for p_name in ['P3', 'P4', 'P5', 'P6', 'P7']] - anchors = __build_anchors(anchor_params, features) + anchors = __build_anchors(anchor_params, model.inputs[0], features) # we expect the anchors, regression and classification values as first output regression = model.outputs[0] diff --git a/keras_retinanet/utils/anchors.py b/keras_retinanet/utils/anchors.py index 08007c02b..43af0c314 100644 --- a/keras_retinanet/utils/anchors.py +++ b/keras_retinanet/utils/anchors.py @@ -234,24 +234,29 @@ def anchors_for_shape( ratios=anchor_params.ratios, scales=anchor_params.scales ) - shifted_anchors = shift(image_shapes[idx], anchor_params.strides[idx], anchors) + shifted_anchors = shift(image_shape, image_shapes[idx], anchor_params.strides[idx], anchors) all_anchors = np.append(all_anchors, shifted_anchors, axis=0) return all_anchors -def shift(shape, stride, anchors): - """ Produce shifted anchors based on shape of the map and stride size. +def shift(image_shape, features_shape, stride, anchors): + """ Produce shifted anchors based on shape of the image, shape of the feature map and stride. Args - shape : Shape to shift the anchors over. - stride : Stride to shift the anchors with over the shape. - anchors: The anchors to apply at each location. + image_shape : Shape of the input image. + features_shape : Shape of the feature map. + stride : Stride to shift the anchors with over the image. + anchors : The anchors to apply at each location. """ + # compute the offset of the anchors based on the image shape and the feature map shape + # see https://github.com/fizyr/keras-retinanet/issues/1073 for more information + offset_x = (image_shape[1] - (features_shape[1] - 1) * stride) / 2.0 + offset_y = (image_shape[0] - (features_shape[0] - 1) * stride) / 2.0 # create a grid starting from half stride from the top left corner - shift_x = (np.arange(0, shape[1]) + 0.5) * stride - shift_y = (np.arange(0, shape[0]) + 0.5) * stride + shift_x = np.arange(0, features_shape[1]) * stride + offset_x + shift_y = np.arange(0, features_shape[0]) * stride + offset_y shift_x, shift_y = np.meshgrid(shift_x, shift_y) diff --git a/tests/backend/test_common.py b/tests/backend/test_common.py index a4f6c3b94..f87aa5b6c 100644 --- a/tests/backend/test_common.py +++ b/tests/backend/test_common.py @@ -62,8 +62,9 @@ def test_bbox_transform_inv(): def test_shift(): - shape = (2, 3) - stride = 8 + image_shape = (20, 20) + feature_shape = (2, 2) + stride = 8 anchors = np.array([ [-8, -8, 8, 8], @@ -75,49 +76,35 @@ def test_shift(): expected = [ # anchors for (0, 0) - [4 - 8, 4 - 8, 4 + 8, 4 + 8], - [4 - 16, 4 - 16, 4 + 16, 4 + 16], - [4 - 12, 4 - 12, 4 + 12, 4 + 12], - [4 - 12, 4 - 16, 4 + 12, 4 + 16], - [4 - 16, 4 - 12, 4 + 16, 4 + 12], + [6 - 8, 6 - 8, 6 + 8, 6 + 8], + [6 - 16, 6 - 16, 6 + 16, 6 + 16], + [6 - 12, 6 - 12, 6 + 12, 6 + 12], + [6 - 12, 6 - 16, 6 + 12, 6 + 16], + [6 - 16, 6 - 12, 6 + 16, 6 + 12], # anchors for (0, 1) - [12 - 8, 4 - 8, 12 + 8, 4 + 8], - [12 - 16, 4 - 16, 12 + 16, 4 + 16], - [12 - 12, 4 - 12, 12 + 12, 4 + 12], - [12 - 12, 4 - 16, 12 + 12, 4 + 16], - [12 - 16, 4 - 12, 12 + 16, 4 + 12], - - # anchors for (0, 2) - [20 - 8, 4 - 8, 20 + 8, 4 + 8], - [20 - 16, 4 - 16, 20 + 16, 4 + 16], - [20 - 12, 4 - 12, 20 + 12, 4 + 12], - [20 - 12, 4 - 16, 20 + 12, 4 + 16], - [20 - 16, 4 - 12, 20 + 16, 4 + 12], + [14 - 8, 6 - 8, 14 + 8, 6 + 8], + [14 - 16, 6 - 16, 14 + 16, 6 + 16], + [14 - 12, 6 - 12, 14 + 12, 6 + 12], + [14 - 12, 6 - 16, 14 + 12, 6 + 16], + [14 - 16, 6 - 12, 14 + 16, 6 + 12], # anchors for (1, 0) - [4 - 8, 12 - 8, 4 + 8, 12 + 8], - [4 - 16, 12 - 16, 4 + 16, 12 + 16], - [4 - 12, 12 - 12, 4 + 12, 12 + 12], - [4 - 12, 12 - 16, 4 + 12, 12 + 16], - [4 - 16, 12 - 12, 4 + 16, 12 + 12], + [6 - 8, 14 - 8, 6 + 8, 14 + 8], + [6 - 16, 14 - 16, 6 + 16, 14 + 16], + [6 - 12, 14 - 12, 6 + 12, 14 + 12], + [6 - 12, 14 - 16, 6 + 12, 14 + 16], + [6 - 16, 14 - 12, 6 + 16, 14 + 12], # anchors for (1, 1) - [12 - 8, 12 - 8, 12 + 8, 12 + 8], - [12 - 16, 12 - 16, 12 + 16, 12 + 16], - [12 - 12, 12 - 12, 12 + 12, 12 + 12], - [12 - 12, 12 - 16, 12 + 12, 12 + 16], - [12 - 16, 12 - 12, 12 + 16, 12 + 12], - - # anchors for (1, 2) - [20 - 8, 12 - 8, 20 + 8, 12 + 8], - [20 - 16, 12 - 16, 20 + 16, 12 + 16], - [20 - 12, 12 - 12, 20 + 12, 12 + 12], - [20 - 12, 12 - 16, 20 + 12, 12 + 16], - [20 - 16, 12 - 12, 20 + 16, 12 + 12], + [14 - 8, 14 - 8, 14 + 8, 14 + 8], + [14 - 16, 14 - 16, 14 + 16, 14 + 16], + [14 - 12, 14 - 12, 14 + 12, 14 + 12], + [14 - 12, 14 - 16, 14 + 12, 14 + 16], + [14 - 16, 14 - 12, 14 + 16, 14 + 12], ] - result = keras_retinanet.backend.shift(shape, stride, anchors) + result = keras_retinanet.backend.shift(image_shape, feature_shape, stride, anchors) result = keras.backend.eval(result) np.testing.assert_array_equal(result, expected) diff --git a/tests/layers/test_misc.py b/tests/layers/test_misc.py index a4364ef67..873d96ee6 100644 --- a/tests/layers/test_misc.py +++ b/tests/layers/test_misc.py @@ -30,12 +30,16 @@ def test_simple(self): scales=np.array([1], keras.backend.floatx()), ) + # create fake image input (only shape is used anyway) + image = np.zeros((1, 16, 16, 3), dtype=keras.backend.floatx()) + image = keras.backend.variable(image) + # create fake features input (only shape is used anyway) features = np.zeros((1, 2, 2, 1024), dtype=keras.backend.floatx()) features = keras.backend.variable(features) # call the Anchors layer - anchors = anchors_layer.call(features) + anchors = anchors_layer.call([image, features]) anchors = keras.backend.eval(anchors) # expected anchor values @@ -59,12 +63,16 @@ def test_mini_batch(self): scales=np.array([1], dtype=keras.backend.floatx()), ) + # create fake image input (only shape is used anyway) + image = np.zeros((2, 16, 16, 3), dtype=keras.backend.floatx()) + image = keras.backend.variable(image) + # create fake features input with batch_size=2 features = np.zeros((2, 2, 2, 1024), dtype=keras.backend.floatx()) features = keras.backend.variable(features) # call the Anchors layer - anchors = anchors_layer.call(features) + anchors = anchors_layer.call([image, features]) anchors = keras.backend.eval(anchors) # expected anchor values diff --git a/tests/utils/test_anchors.py b/tests/utils/test_anchors.py index a684d0300..54909ff62 100644 --- a/tests/utils/test_anchors.py +++ b/tests/utils/test_anchors.py @@ -167,3 +167,29 @@ def test_anchors_for_shape_values(): strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, ], decimal=6) + + +def test_anchors_for_shape_odd_input(): + pyramid_levels = [3] + image_shape = (20, 20) # this shape causes rounding errors when downsampling using convolutions + sizes = [32] + strides = [8] + ratios = np.array([1], keras.backend.floatx()) + scales = np.array([1], keras.backend.floatx()) + anchor_params = AnchorParameters(sizes, strides, ratios, scales) + + anchors = anchors_for_shape(image_shape, pyramid_levels = pyramid_levels, anchor_params = anchor_params) + + expected_anchors = np.array([ + [-14, -14, 18, 18], + [-6 , -14, 26, 18], + [2 , -14, 34, 18], + [-14, -6 , 18, 26], + [-6 , -6 , 26, 26], + [2 , -6 , 34, 26], + [-14, 2 , 18, 34], + [-6 , 2 , 26, 34], + [2 , 2 , 34, 34], + ]) + + np.testing.assert_equal(anchors, expected_anchors)