Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model modifications #53

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions models/frustum_pointnets_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def get_model(point_cloud, one_hot_vec, is_training, bn_decay=None):
# select masked points and translate to masked points' centroid
object_point_cloud_xyz, mask_xyz_mean, end_points = \
point_cloud_masking(point_cloud, logits, end_points)

end_points['object_point_cloud_xyz'] = object_point_cloud_xyz
end_points['mask_xyz_mean'] = mask_xyz_mean

# T-Net and coordinate translation
center_delta, end_points = get_center_regression_net(\
Expand Down
5 changes: 4 additions & 1 deletion models/frustum_pointnets_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_3d_box_estimation_v2_net(object_point_cloud, one_hot_vec,
is_training=is_training, bn_decay=bn_decay, scope='ssg-layer3')

# Fully connected layers
net = tf.reshape(l3_points, [batch_size, -1])
net = tf.contrib.layers.flatten(l3_points)
net = tf.concat([net, one_hot_vec], axis=1)
net = tf_util.fully_connected(net, 512, bn=True,
is_training=is_training, scope='fc1', bn_decay=bn_decay)
Expand Down Expand Up @@ -147,6 +147,9 @@ def get_model(point_cloud, one_hot_vec, is_training, bn_decay=None):
# select masked points and translate to masked points' centroid
object_point_cloud_xyz, mask_xyz_mean, end_points = \
point_cloud_masking(point_cloud, logits, end_points)

end_points['object_point_cloud_xyz'] = object_point_cloud_xyz
end_points['mask_xyz_mean'] = mask_xyz_mean

# T-Net and coordinate translation
center_delta, end_points = get_center_regression_net(\
Expand Down
13 changes: 6 additions & 7 deletions models/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def get_box3d_corners_helper(centers, headings, sizes):
#print x_corners, y_corners, z_corners
c = tf.cos(headings)
s = tf.sin(headings)
ones = tf.ones([N], dtype=tf.float32)
zeros = tf.zeros([N], dtype=tf.float32)
ones = tf.ones(tf.shape(headings), dtype=tf.float32)
zeros = tf.zeros(tf.shape(headings), dtype=tf.float32)
row1 = tf.stack([c,zeros,s], axis=1) # (N,3)
row2 = tf.stack([zeros,ones,zeros], axis=1)
row3 = tf.stack([-s,zeros,c], axis=1)
Expand All @@ -106,16 +106,15 @@ def get_box3d_corners(center, heading_residuals, size_residuals):
heading_bin_centers = tf.constant(np.arange(0,2*np.pi,2*np.pi/NUM_HEADING_BIN), dtype=tf.float32) # (NH,)
headings = heading_residuals + tf.expand_dims(heading_bin_centers, 0) # (B,NH)

mean_sizes = tf.expand_dims(tf.constant(g_mean_size_arr, dtype=tf.float32), 0) + size_residuals # (B,NS,1)
mean_sizes = tf.expand_dims(tf.constant(g_mean_size_arr, dtype=tf.float32), 0) # (1,NS,3)
sizes = mean_sizes + size_residuals # (B,NS,3)
sizes = tf.tile(tf.expand_dims(sizes,1), [1,NUM_HEADING_BIN,1,1]) # (B,NH,NS,3)
headings = tf.tile(tf.expand_dims(headings,-1), [1,1,NUM_SIZE_CLUSTER]) # (B,NH,NS)
centers = tf.tile(tf.expand_dims(tf.expand_dims(center,1),1), [1,NUM_HEADING_BIN, NUM_SIZE_CLUSTER,1]) # (B,NH,NS,3)

N = batch_size*NUM_HEADING_BIN*NUM_SIZE_CLUSTER
corners_3d = get_box3d_corners_helper(tf.reshape(centers, [N,3]), tf.reshape(headings, [N]), tf.reshape(sizes, [N,3]))
corners_3d = get_box3d_corners_helper(tf.reshape(centers, [-1,3]), tf.reshape(headings, [-1,]), tf.reshape(sizes, [-1,3]))

return tf.reshape(corners_3d, [batch_size, NUM_HEADING_BIN, NUM_SIZE_CLUSTER, 8, 3])
return tf.reshape(corners_3d, [-1, NUM_HEADING_BIN, NUM_SIZE_CLUSTER, 8, 3])


def huber_loss(error, delta):
Expand Down Expand Up @@ -152,7 +151,7 @@ def parse_output_to_tensors(output, end_points):
size_residuals_normalized = tf.slice(output,
[0,3+NUM_HEADING_BIN*2+NUM_SIZE_CLUSTER], [-1,NUM_SIZE_CLUSTER*3])
size_residuals_normalized = tf.reshape(size_residuals_normalized,
[batch_size, NUM_SIZE_CLUSTER, 3]) # BxNUM_SIZE_CLUSTERx3
[-1, NUM_SIZE_CLUSTER, 3]) # BxNUM_SIZE_CLUSTERx3
end_points['size_scores'] = size_scores
end_points['size_residuals_normalized'] = size_residuals_normalized
end_points['size_residuals'] = size_residuals_normalized * \
Expand Down
6 changes: 3 additions & 3 deletions models/pointnet_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def sample_and_group_all(xyz, points, use_xyz=True):
'''
batch_size = xyz.get_shape()[0].value
nsample = xyz.get_shape()[1].value
new_xyz = tf.constant(np.tile(np.array([0,0,0]).reshape((1,1,3)), (batch_size,1,1)),dtype=tf.float32) # (batch_size, 1, 3)
idx = tf.constant(np.tile(np.array(range(nsample)).reshape((1,1,nsample)), (batch_size,1,1)))
grouped_xyz = tf.reshape(xyz, (batch_size, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3)
new_xyz = tf.zeros(tf.shape(xyz[:,:1,:]),dtype=tf.float32) # (batch_size, 1, 3)
idx = tf.tile(np.array(range(nsample)).reshape((1,1,nsample)), tf.shape(xyz[:,:1,:1]))
grouped_xyz = tf.reshape(xyz, (-1, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3)
if points is not None:
if use_xyz:
new_points = tf.concat([xyz, points], axis=2) # (batch_size, 16, 259)
Expand Down