-
Notifications
You must be signed in to change notification settings - Fork 7
/
H36M_MTVAEPredModel.py
110 lines (91 loc) · 4.51 KB
/
H36M_MTVAEPredModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Helper functions for training Motion Transformation VAE on Human3.6M."""
import os
import numpy as np
import tensorflow as tf
import h36m_losses as losses
import H36M_BasePredModel as BasePredModel
from nets import h36m_mtvae_factory as model_factory
slim = tf.contrib.slim
class MTVAEPredModel(BasePredModel.BasePredModel):
    """Defines MTVAE Prediction Model.

    Builds the model and sampling functions through `model_factory` and
    assembles the training loss from the terms enabled on the
    hyper-parameter object read from `self._params` (presumably stored by
    the base-class constructor -- confirm in H36M_BasePredModel).
    """

    def __init__(self, params):
        """Forwards `params` to the base-class constructor."""
        super(MTVAEPredModel, self).__init__(params)

    def get_model_fn(self, is_training, use_prior, reuse):
        """Returns the model function built by `model_factory.get_model_fn`.

        Args:
          is_training: Forwarded unchanged to the factory.
          use_prior: Forwarded unchanged to the factory.
          reuse: Forwarded unchanged to the factory.

        Returns:
          Whatever `model_factory.get_model_fn` returns for `self._params`.
        """
        return model_factory.get_model_fn(
            self._params, is_training, use_prior, reuse)

    def get_sample_fn(self, is_training, use_prior, reuse, output_length=None):
        """Returns the sampling function built by `model_factory.get_sample_fn`.

        Args:
          is_training: Forwarded unchanged to the factory.
          use_prior: Forwarded unchanged to the factory.
          reuse: Forwarded unchanged to the factory.
          output_length: Forwarded unchanged to the factory (default None).

        Returns:
          Whatever `model_factory.get_sample_fn` returns for `self._params`.
        """
        return model_factory.get_sample_fn(
            self._params, is_training, use_prior, reuse, output_length)

    @staticmethod
    def _scheduled_weight(final_weight, start_weight, decay_rate, step):
        """Returns the loss weight scheduled for `step`.

        Computes `final - (final - start) * decay_rate ** step`, which equals
        `start_weight` at step 0. `step` is cast to float32, so the result is
        a tensor.
        """
        decay = decay_rate ** tf.to_float(step)
        return final_weight - (final_weight - start_weight) * decay

    def _weight_enabled(self, weight_name):
        """Returns True if `self._params` defines `weight_name` with a value > 0."""
        return getattr(self._params, weight_name, 0) > 0

    def _cycle_loss_enabled(self):
        """Returns True if the cycle loss term is part of the objective.

        Shared by `get_loss` and `print_running_loss` so that both agree on
        whether `loss_dict['cycle_loss']` exists.
        """
        cycle_model = bool(getattr(self._params, 'cycle_model', False))
        return cycle_model and self._weight_enabled('cycle_weight')

    def get_loss(self, step, inputs, outputs):
        """Builds the total training loss from the enabled loss terms.

        A term is enabled when its weight attribute exists on `self._params`
        and is greater than zero. The total is also registered as the scalar
        summary 'keypoint_mtvae_loss' under the 'losses' prefix.

        Args:
          step: Training step (number or tensor); cast to float32 to drive
            the velocity and KL weight schedules.
          inputs: Dict passed to the `losses` helpers. This method itself
            reads 'last_landmarks', 'fut_landmarks' and 'fut_lens'.
          outputs: Dict passed to the `losses` helpers. This method itself
            reads 'fut_landmarks' and 'cycle_fut_landmarks'.

        Returns:
          A tuple `(total_loss, loss_dict)`. `total_loss` is the sum of the
          enabled terms, starting from a float32 scalar zero. `loss_dict`
          maps 'post_keypoint_loss', 'velocity_loss', 'kl_loss' and
          'cycle_loss' to their tensors, for the enabled terms only.

        Raises:
          ValueError: If the velocity loss is enabled while
            `params.cycle_model` is not set.
        """
        params = self._params
        total_loss = tf.zeros(dtype=tf.float32, shape=[])
        loss_dict = dict()

        if self._weight_enabled('keypoint_weight'):
            keypoint_loss = losses.get_keypoint_loss(
                inputs, outputs, params.max_length, params.keypoint_weight)
            loss_dict['post_keypoint_loss'] = keypoint_loss
            total_loss += keypoint_loss

        if self._weight_enabled('velocity_weight'):
            # The second velocity term reads outputs['cycle_fut_landmarks'].
            # Raise explicitly instead of using `assert`, which is stripped
            # under `python -O`.
            if not getattr(params, 'cycle_model', False):
                raise ValueError(
                    'velocity_weight > 0 requires params.cycle_model to be set.')
            curr_velocity_weight = self._scheduled_weight(
                params.velocity_weight, params.velocity_start_weight,
                params.velocity_decay_rate, step)
            # The velocity terms are additionally scaled by keypoint_weight.
            scaled_weight = curr_velocity_weight * params.keypoint_weight
            velocity_loss = losses.get_velocity_loss(
                inputs['last_landmarks'], inputs['fut_landmarks'],
                outputs['fut_landmarks'], inputs['fut_lens'],
                scaled_weight, 'post_velocity_loss', params.velocity_length)
            velocity_loss += losses.get_velocity_loss(
                inputs['last_landmarks'], inputs['fut_landmarks'],
                outputs['cycle_fut_landmarks'], inputs['fut_lens'],
                scaled_weight, 'prior_velocity_loss', params.velocity_length)
            loss_dict['velocity_loss'] = velocity_loss
            total_loss += velocity_loss

        if self._weight_enabled('kl_weight'):
            curr_kl_weight = self._scheduled_weight(
                params.kl_weight, params.kl_start_weight,
                params.kl_decay_rate, step)
            kl_loss = losses.get_kl_loss(
                inputs, outputs, curr_kl_weight, params.kl_tolerance)
            loss_dict['kl_loss'] = kl_loss
            total_loss += kl_loss

        if self._cycle_loss_enabled():
            cycle_loss = losses.get_cycle_loss(
                inputs, outputs, params.cycle_weight)
            loss_dict['cycle_loss'] = cycle_loss
            total_loss += cycle_loss

        slim.summaries.add_scalar_summary(
            total_loss, 'keypoint_mtvae_loss', prefix='losses')
        return total_loss, loss_dict

    def print_running_loss(self, global_step, loss_dict):
        """Returns an op that prints the per-term losses when evaluated.

        Each enabled loss is divided by the weight `get_loss` passed to the
        corresponding `losses` helper; disabled terms are reported as 0.
        NOTE(review): this presumably recovers the unweighted loss, assuming
        the helpers scale their result linearly by the weight -- confirm in
        h36m_losses.

        Args:
          global_step: Training step (number or tensor); printed and used to
            recompute the velocity and KL weight schedules.
          loss_dict: The dict returned by `get_loss` for the same params.

        Returns:
          An int32 tensor produced by a `tf.py_func`; evaluating it prints
          one line with the step and the four normalized losses.
        """
        params = self._params

        if self._weight_enabled('keypoint_weight'):
            norm_keypoint_loss = (
                loss_dict['post_keypoint_loss'] / params.keypoint_weight)
        else:
            norm_keypoint_loss = 0

        if self._weight_enabled('kl_weight'):
            curr_kl_weight = self._scheduled_weight(
                params.kl_weight, params.kl_start_weight,
                params.kl_decay_rate, global_step)
            norm_kl_loss = loss_dict['kl_loss'] / curr_kl_weight
        else:
            norm_kl_loss = 0

        if self._weight_enabled('velocity_weight'):
            curr_velocity_weight = self._scheduled_weight(
                params.velocity_weight, params.velocity_start_weight,
                params.velocity_decay_rate, global_step)
            norm_velocity_loss = loss_dict['velocity_loss'] / (
                curr_velocity_weight * params.keypoint_weight)
        else:
            norm_velocity_loss = 0

        # Same condition as get_loss: 'cycle_loss' is only present in
        # loss_dict when cycle_model is set as well as cycle_weight > 0.
        if self._cycle_loss_enabled():
            norm_cycle_loss = loss_dict['cycle_loss'] / params.cycle_weight
        else:
            norm_cycle_loss = 0

        def print_loss(step, keypoint_loss, kl_loss, velocity_loss, cycle_loss):
            print('[%06d]\t[Keypoint %.3f]\t[KL %.3f]\t[VF %.3f]\t[CYCLE %.3f]' % \
                  (step, keypoint_loss, kl_loss, velocity_loss, cycle_loss))
            # Explicit int64 so the result matches Tout on every platform.
            return np.int64(0)

        ret_tmp = tf.py_func(
            func=print_loss,
            inp=[global_step, norm_keypoint_loss, norm_kl_loss,
                 norm_velocity_loss, norm_cycle_loss],
            Tout=[tf.int64], name='print_loss')[0]
        ret_tmp = tf.to_int32(ret_tmp)
        return ret_tmp