diff --git a/menpofit/transform/modeldriven.py b/menpofit/transform/modeldriven.py index 3d35789..ddf8d00 100644 --- a/menpofit/transform/modeldriven.py +++ b/menpofit/transform/modeldriven.py @@ -262,7 +262,13 @@ def d_dp(self, points): # dW_dl: n_points x (n_dims) x n_centres x n_dims # dX_dp: (n_points x n_dims) x n_params - dW_dp = np.einsum('ild, lpd -> ipd', self.dW_dl, dX_dp) + + + # The following is equivalent to + # np.einsum('ild, lpd -> ipd', self.dW_dl, dX_dp) + dW_dp = np.tensordot(self.dW_dl, dX_dp, (1, 0)) + dW_dp = dW_dp.diagonal(axis1=3, axis2=1) + # dW_dp: n_points x n_params x n_dims return dW_dp