Skip to content

Commit

Permalink
Update dir.py
Browse files: browse the repository at this point in the history
  • Loading branch information
PengfeiRen96 authored Oct 8, 2023
1 parent 6ae2e66 commit 329e381
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions models/dir.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@

# Joint Space Interaction + Project Joint Feature to Image Space
class Joint2BoneFeature(nn.Module):
def __init__(self, img_feat_dim, emd_dim, joint_dim, joint_num, feature_size, mano_pth, distance=1):
def __init__(self, img_feat_dim, emd_dim, joint_dim, joint_num, feature_size, mano_pth, root_joint, distance=1):
super(Joint2BoneFeature, self).__init__()
edge = get_sketch_setting()
adj = adj_mx_from_edges(joint_num, edge, sparse=False, eye=False)
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, img_feat_dim, emd_dim, joint_dim, joint_num, feature_size, ma
nn.Conv2d(img_feat_dim, img_feat_dim, 1)
)

self.regressor = RegressorOffset(joint_num * joint_dim, mano_pth)
self.regressor = RegressorOffset(joint_num * joint_dim, mano_pth, root_joint)

x = (torch.arange(feature_size) + 0.5)
y = (torch.arange(feature_size) + 0.5)
Expand Down Expand Up @@ -216,12 +216,12 @@ def forward(self, x):


class InitRegressor(nn.Module):
def __init__(self, feat_dim, mano_path):
def __init__(self, feat_dim, mano_path, root_joint):
super(InitRegressor, self).__init__()
self.mano_layer_right = ObmanManoLayer(root_rot_mode='6D', joint_rot_mode='axisang', use_pca=True, mano_root=mano_path,
side='right', ncomps=45, center_idx=9, flat_hand_mean=False, robust_rot=True)
side='right', ncomps=45, center_idx=root_joint, flat_hand_mean=False, robust_rot=True)
self.mano_layer_left = ObmanManoLayer(root_rot_mode='6D', joint_rot_mode='axisang', use_pca=True, mano_root=mano_path,
side='left', ncomps=45, center_idx=9, flat_hand_mean=False, robust_rot=True)
side='left', ncomps=45, center_idx=root_joint, flat_hand_mean=False, robust_rot=True)
self.fix_shape(self.mano_layer_left, self.mano_layer_right)

self.attention_left = nn.Sequential(
Expand Down Expand Up @@ -310,12 +310,12 @@ def fix_shape(self, mano_layer_left, mano_layer_right):


class RegressorOffset(nn.Module):
def __init__(self, feat_dim, mano_path):
def __init__(self, feat_dim, mano_path, root_joint):
super(RegressorOffset, self).__init__()
self.mano_layer_right = ObmanManoLayer(root_rot_mode='6D', joint_rot_mode='axisang', use_pca=True, mano_root=mano_path,
side='right', ncomps=45, center_idx=9, flat_hand_mean=False, robust_rot=True)
side='right', ncomps=45, center_idx=root_joint, flat_hand_mean=False, robust_rot=True)
self.mano_layer_left = ObmanManoLayer(root_rot_mode='6D', joint_rot_mode='axisang', use_pca=True, mano_root=mano_path,
side='left', ncomps=45, center_idx=9, flat_hand_mean=False, robust_rot=True)
side='left', ncomps=45, center_idx=root_joint, flat_hand_mean=False, robust_rot=True)
self.fix_shape(self.mano_layer_left, self.mano_layer_right)

mano_para_dim = 3 * 2 + 15 * 3 + 10 + 3
Expand Down Expand Up @@ -355,7 +355,6 @@ def forward(self, sampled_feat_l, sampled_feat_r, mano_para_l_init, mano_para_r_
pd_mesh_xyz_left, pd_joint_xyz_left = self.mano_layer_left(pd_mano_pose_left, pd_mano_beta_left)
pd_mesh_xyz_right, pd_joint_xyz_right = self.mano_layer_right(pd_mano_pose_right, pd_mano_beta_right)


pd_joint_uv_left = projection_batch_xy(pd_para_left[:, 0], pd_para_left[:, 1:], pd_joint_xyz_left)
pd_mesh_uv_left = projection_batch_xy(pd_para_left[:, 0], pd_para_left[:, 1:], pd_mesh_xyz_left)
pd_joint_uv_right = projection_batch_xy(pd_para_right[:, 0], pd_para_right[:, 1:], pd_joint_xyz_right)
Expand Down Expand Up @@ -388,18 +387,18 @@ def fix_shape(self, mano_layer_left, mano_layer_right):


class FusionJointInterIterDecoder(nn.Module):
def __init__(self, joint_num, mano_pth, inDim=[2048, 1024, 512, 256], fDim=[256, 256, 256, 256]):
def __init__(self, joint_num, mano_pth, root_joint, inDim=[2048, 1024, 512, 256], fDim=[256, 256, 256, 256]):
super(FusionJointInterIterDecoder, self).__init__()
self.up4 = nn.Upsample(scale_factor=2, mode='bilinear')
self.skip_layer4 = Residual(inDim[1], fDim[0])
self.fusion_layer4 = Residual(inDim[0] + fDim[0], fDim[1])
self.projecter_4 = Joint2BoneFeature(fDim[1], 128, 64, joint_num, 16, mano_pth, distance=1)
self.projecter_4 = Joint2BoneFeature(fDim[1], 128, 64, joint_num, 16, mano_pth, root_joint, distance=1)
self.enhance_layer4 = Residual(fDim[1] * 2, fDim[1])

self.up3 = nn.Upsample(scale_factor=2, mode='bilinear')
self.skip_layer3 = Residual(inDim[2], fDim[1])
self.fusion_layer3 = Residual(fDim[1] * 2, fDim[2])
self.projecter_3 = Joint2BoneFeature(fDim[2], 128, 64, joint_num, 32, mano_pth, distance=2)
self.projecter_3 = Joint2BoneFeature(fDim[2], 128, 64, joint_num, 32, mano_pth, root_joint, distance=2)
self.enhance_layer3 = Residual(fDim[2] * 2, fDim[2])

self.conv_final = nn.Sequential(
Expand Down Expand Up @@ -485,7 +484,7 @@ def forward(self, x, result_dict):


class DIR(nn.Module):
def __init__(self, joint_num, mano_path):
def __init__(self, joint_num, mano_path, root_joint=0):
super(DIR, self).__init__()
self.joint_num = joint_num
weights = ResNet50_Weights.IMAGENET1K_V2
Expand All @@ -499,8 +498,8 @@ def __init__(self, joint_num, mano_path):
self.backbone.load_state_dict(model_dict)

self.mesh_sample_num = joint_num
self.init_regressor = InitRegressor(self.backbone.inplanes, mano_path)
self.decoder = FusionJointInterIterDecoder(self.joint_num, mano_path)
self.init_regressor = InitRegressor(self.backbone.inplanes, mano_path, root_joint)
self.decoder = FusionJointInterIterDecoder(self.joint_num, mano_path, root_joint)

self.coord_weight = 10
self.dense_weight = 1
Expand Down Expand Up @@ -560,9 +559,6 @@ def forward(self, input, target, meta_info):
gt_mesh_normal_xyz_right = (gt_mesh_xyz_right - gt_center_right) / 0.15
gt_offset = (gt_center_right - gt_center_left) / 0.15

length_left_gt = torch.linalg.norm(gt_joint_normal_xyz_left[:, 9] - gt_joint_normal_xyz_left[:, 0], dim=-1)
length_right_gt = torch.linalg.norm(gt_joint_normal_xyz_right[:, 9] - gt_joint_normal_xyz_right[:, 0], dim=-1)

map_size = decode_list['seg'].size(-1)
gt_seg_id = target['seg'].cuda()
gt_dense = target['dense'].cuda()
Expand All @@ -580,13 +576,9 @@ def forward(self, input, target, meta_info):

joints_left_pred, joints_right_pred = out["pd_joint_xyz_left"] / 0.15, out["pd_joint_xyz_right"] / 0.15
meshs_left_pred, meshs_right_pred = out["pd_mesh_xyz_left"] / 0.15, out["pd_mesh_xyz_right"] / 0.15
length_left_pd = torch.linalg.norm(joints_left_pred[:, 9] - joints_left_pred[:, 0],dim=-1)
length_right_pd = torch.linalg.norm(joints_right_pred[:, 9] - joints_right_pred[:, 0],dim=-1)

scale_left = (length_left_gt / length_left_pd).unsqueeze(-1).unsqueeze(-1)
scale_right = (length_right_gt / length_right_pd).unsqueeze(-1).unsqueeze(-1)
joints_left_pred, joints_right_pred = joints_left_pred*scale_left, joints_right_pred*scale_right
meshs_left_pred, meshs_right_pred = meshs_left_pred*scale_left, meshs_right_pred*scale_right
joints_left_pred, joints_right_pred = joints_left_pred, joints_right_pred
meshs_left_pred, meshs_right_pred = meshs_left_pred, meshs_right_pred

loss['joint_left_xyz_%d' % (index)] = self.l1_loss(joints_left_pred, gt_joint_normal_xyz_left) * self.coord_weight
loss['joint_right_xyz_%d' % (index)] = self.l1_loss(joints_right_pred, gt_joint_normal_xyz_right) * self.coord_weight
Expand Down

0 comments on commit 329e381

Please sign in to comment.