Commit 3fbf31f: Updated the variational model
jbhayet committed May 3, 2024 (1 parent: 3fa9890)
Showing 3 changed files with 27 additions and 59 deletions.
40 changes: 6 additions & 34 deletions scripts/train_variational.py
@@ -32,7 +32,7 @@
 from trajpred_unc.utils.constants import SUBDATASETS_NAMES
 from trajpred_unc.utils.config import load_config,get_model_filename
 from trajpred_unc.uncertainties.calibration import generate_uncertainty_evaluation_dataset
-from trajpred_unc.uncertainties.calibration_utils import save_data_for_calibration
+from trajpred_unc.uncertainties.calibration_utils import save_data_for_uncertainty_calibration

 # Parser arguments
 config = load_config("deterministic_variational_ethucy.yaml")
@@ -85,16 +85,12 @@ def main():

     # For each element of the ensemble
     for ind in range(config["train"]["num_mctrain"]):

-        if torch.cuda.is_available():
-            observations_vel = observations_vel.to(device)
-
-        predictions,__,sigmas = model.predict(observations_vel)
-
+        predictions,__,sigmas = model.predict(observations_vel,observations_abs)
         # Plotting
         plot_traj_world(predictions[ind_sample,:,:], observations_abs[ind_sample,:,:], target_abs[ind_sample,:,:], ax)
         plot_cov_world(predictions[ind_sample,:,:],sigmas[ind_sample,:,:],observations_abs[ind_sample,:,:], ax)
     plt.legend()
     plt.title('Trajectory samples {}'.format(batch_idx))
     plt.savefig(config["misc"]["plot_dir"]+"/pred_variational.pdf")
     if config["misc"]["show_test"]:
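The call above now passes the absolute observations alongside the velocities, since predict() anchors its output in world coordinates. Below is a minimal sketch of the Monte Carlo sampling pattern this loop relies on: each stochastic forward pass corresponds to a fresh draw of variational weights. The stub, the shapes, and the constants are all hypothetical stand-ins, not the repository's actual API.

```python
import numpy as np
import torch

num_mc = 5                                # stand-in for config["train"]["num_mctrain"]
observations_vel = torch.randn(16, 8, 2)  # (batch, obs_len, vx/vy), assumed shape
observations_abs = torch.cumsum(observations_vel, dim=1)  # stand-in world positions

def predict_stub(vel, pos):
    """Hypothetical stand-in for model.predict(): each call simulates a fresh
    draw of variational weights and returns (positions, kl, sigmas)."""
    noise = 0.1 * torch.randn(vel.size(0), 12, 2)
    pred = pos[:, -1:, :] + torch.cumsum(noise, dim=1)
    return pred.numpy(), 0.0, noise.abs().numpy()

# One trajectory sample per stochastic forward pass
samples = np.stack([predict_stub(observations_vel, observations_abs)[0]
                    for _ in range(num_mc)])
print(samples.shape)  # (5, 16, 12, 2): num_mc trajectories per agent
```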
@@ -108,37 +104,13 @@ def main():
         draw_ellipse = True

     #------------------ Generates sub-dataset for calibration evaluation ---------------------------
-    __,__,observations_abs_e,target_abs_e,predictions_e,sigmas_e = generate_uncertainty_evaluation_dataset(batched_test_data, model,config,device=device,type="variational")
+    __,__,observations_abs,target_abs,predictions,sigmas = generate_uncertainty_evaluation_dataset(batched_test_data, model,config,device=device,type="variational")
     #---------------------------------------------------------------------------------------------------------------

-    # Testing
-    cont = 0
-    for batch_idx, (observations_vel_c,__,observations_abs_c,target_abs_c,__,__,__) in enumerate(batched_test_data):
-
-        predictions_c = []
-        sigmas_c = []
-        # Sample with each model
-        for ind in range(config["train"]["num_mctrain"]):
-
-            if torch.cuda.is_available():
-                observations_vel_c = observations_vel_c.to(device)
-
-            predictions, kl, sigmas = model.predict(observations_vel_c)
-
-            predictions_c.append(predictions)
-            sigmas_c.append(sigmas)
-
-        predictions_c = np.array(predictions_c)
-        sigmas_c = np.array(sigmas_c)
+    # Save these testing data for uncertainty calibration
+    pickle_filename = config["train"]["model_name"]+"_"+SUBDATASETS_NAMES[config["dataset"]["id_dataset"]][config["dataset"]["id_test"]]
+    save_data_for_uncertainty_calibration(pickle_filename,predictions,observations_abs,target_abs,sigmas,config["dataset"]["id_test"])

-        # Save these testing data for uncertainty calibration
-        pickle_filename = config["train"]["model_name"]+"_"+SUBDATASETS_NAMES[config["dataset"]["id_dataset"]][config["dataset"]["id_test"]]
-        #save_data_for_calibration(pickle_filename, tpred_samples, tpred_samples_full, data_test, data_test_full, target_test, target_test_full, sigmas_samples, sigmas_samples_full, config.id_test)
-        save_data_for_calibration(pickle_filename,predictions_c,predictions_e, observations_abs_c,observations_abs_e,target_abs_c,target_abs_e,sigmas_c,sigmas_e,config["dataset"]["id_test"])
-
-        # Only run for one batch
-        break


 if __name__ == "__main__":
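The hand-rolled per-batch testing loop deleted above is now folded into generate_uncertainty_evaluation_dataset. For reference, here is a rough sketch of the pattern it replaces, reconstructed from the deleted code; the helper's actual signature lives in trajpred_unc and may differ.

```python
import numpy as np

def mc_sample_batch(model, observations_vel, observations_abs, num_mc):
    """Run num_mc stochastic forward passes over one batch and stack the
    results, as the deleted loop did. Illustrative only."""
    predictions, sigmas = [], []
    for _ in range(num_mc):  # each pass resamples the variational weights
        pred, _, sigma = model.predict(observations_vel, observations_abs)
        predictions.append(pred)
        sigmas.append(sigma)
    # (num_mc, batch, pred_len, ...) arrays, ready to pickle for calibration
    return np.array(predictions), np.array(sigmas)
```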
42 changes: 19 additions & 23 deletions trajpred_unc/models/bayesian_models_gaussian_loss.py
@@ -175,7 +175,7 @@ def decode(self, last_pos, hidden_state):
         sigma_pos = dec[:,:,2:]
         return pred_pos, sigma_pos, hidden_state, kl_sum

-    def forward(self, X, y, data_abs, target_abs, training=False, num_mc=1):
+    def forward(self,obs_vels,target_vels,obs_abs,target_abs,teacher_forcing=False, num_mc=1):

         nll_loss = 0
         output_ = []
@@ -185,50 +185,48 @@ def forward(self, X, y, data_abs, target_abs, training=False, num_mc=1):
         for mc_run in range(num_mc):
             kl_sum = 0
             # Encode the past trajectory
-            last_pos, hidden_state, kl = self.encode(X)
+            last_vel, hidden_state, kl = self.encode(obs_vels)
             kl_sum += kl

             # Iterate for each time step
             loss = 0
-            pred_traj = []
-            sigma_traj = []
+            pred_vels = []
+            sigma_vels = []

-            for i, target in enumerate(y.permute(1,0,2)):
+            for i, target_vel in enumerate(target_vels.permute(1,0,2)):
                 # Decode last position and hidden state into new position
-                pred_pos, sigma_pos, hidden_state, kl = self.decode(last_pos,hidden_state)
+                pred_vel, sigma_vel, hidden_state, kl = self.decode(last_vel,hidden_state)
                 if i==0:
                     kl_sum += kl
                 # Keep new position and variance
-                pred_traj.append(pred_pos)
-                sigma_traj.append(sigma_pos)
+                pred_vels.append(pred_vel)
+                sigma_vels.append(sigma_vel)
                 # Update the last position
-                if training:
-                    last_pos = target.view(len(target), 1, -1)
+                if teacher_forcing:
+                    last_vel = target_vel.view(len(target_vel), 1, -1)
                 else:
-                    last_pos = pred_pos
+                    last_vel = pred_vel

-                # Use the new loss function
-                means_traj = data_abs[:,-1,:] + torch.cat(pred_traj, dim=1).sum(1)
-                loss += Gaussian2DLikelihood(target_abs[:,i,:], means_traj, torch.cat(sigma_traj, dim=1), self.dt)
+                pred_abs = obs_abs[:,-1,:] + torch.mul(torch.cat(pred_vels, dim=1).sum(1),self.dt)
+                loss += Gaussian2DLikelihood(target_abs[:,i,:], pred_abs, torch.cat(sigma_vels, dim=1),self.dt)

             # Concatenate the predicted trajectories
-            pred_traj = torch.cat(pred_traj, dim=1)
+            pred_vels = torch.cat(pred_vels, dim=1)
             nll_loss += loss/num_mc

             # Save to list
-            output_.append(pred_traj)
+            output_.append(pred_vels)
             kl_.append(kl_sum)
         pred = torch.mean(torch.stack(output_), dim=0)
         kl_loss = torch.mean(torch.stack(kl_), dim=0)
-        # Compute the NLL loss
-        #nll_loss = self.loss_fun(pred, y)
         # Concatenate the predictions and return
         return pred, nll_loss, kl_loss

-    def predict(self, X, dim_pred= 1):
+    def predict(self, obs_vels, obs_pos, dim_pred=12):
         kl_sum = 0
         # Encode the past trajectory
-        last_pos, hidden_state, kl = self.encode(X)
+        last_pos, hidden_state, kl = self.encode(obs_vels)
         kl_sum += kl

         pred_traj = []
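forward() now decodes velocities and rebuilds the absolute position as the last observed position plus dt times the running sum of predicted velocities. A small numerical check of that reconstruction; dt = 0.4 is an assumption matching the usual ETH/UCY frame period, and the values are made up.

```python
import torch

dt = 0.4                                   # assumed frame period
last_obs_pos = torch.tensor([[1.0, 2.0]])  # (batch=1, 2)
pred_vels = torch.tensor([[[0.5, 0.0],     # (batch=1, 2 steps, vx/vy)
                           [0.5, 0.5]]])

# Position after two steps: p_obs + dt * (v_1 + v_2)
pred_abs = last_obs_pos + dt * pred_vels.sum(dim=1)
print(pred_abs)  # tensor([[1.4000, 2.2000]])
```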
@@ -248,9 +246,7 @@ def predict(self, X, dim_pred=1):
             last_pos = pred_pos

         # Concatenate the predictions and return
-        pred_traj = torch.cumsum(torch.cat(pred_traj, dim=1), dim=1).detach().cpu().numpy()
-        sigma_traj = torch.cumsum(torch.cat(sigma_traj, dim=1), dim=1).detach().cpu().numpy()
-
+        pred_traj = self.dt*torch.cumsum(torch.cat(pred_traj, dim=1), dim=1).detach().cpu().numpy()+obs_pos[:,-1:,:].cpu().numpy()
+        sigma_traj = self.dt*self.dt*torch.cumsum(torch.cat(sigma_traj, dim=1), dim=1).detach().cpu().numpy()
         return pred_traj, kl_sum, sigma_traj
         #return torch.cat(pred, dim=1).detach().cpu().numpy(), kl_sum, torch.cat(sigma, dim=1).detach().cpu().numpy()
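predict() performs the same integration in one shot: a dt-scaled cumulative sum of decoded velocities anchored at the last observed position, with the per-step variances picking up a dt² factor (since Var(dt·v) = dt²·Var(v)). A shape-level sketch under assumed tensor layouts and an assumed dt:

```python
import torch

dt = 0.4                           # assumed frame period
obs_pos = torch.randn(16, 8, 2)    # (batch, obs_len, 2), assumed layout
pred_vel = torch.randn(16, 12, 2)  # decoded velocities, (batch, pred_len, 2)
sigma_vel = torch.rand(16, 12, 3)  # per-step velocity variance terms

pred_traj = dt * torch.cumsum(pred_vel, dim=1) + obs_pos[:, -1:, :]
sigma_traj = dt * dt * torch.cumsum(sigma_vel, dim=1)
print(pred_traj.shape, sigma_traj.shape)  # torch.Size([16, 12, 2]) torch.Size([16, 12, 3])
```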
4 changes: 2 additions & 2 deletions trajpred_unc/utils/train_utils.py
@@ -142,7 +142,7 @@ def train_variational(model,device,train_data,val_data,config):
             # Step 3. Compute the gradients and update the parameters
             loss.backward()
             optimizer.step()
-        logging.info("Training loss: {:6.3f}".format(error/total))
+        logging.info("Training variational loss: {:6.3f}".format(error/total))
         list_loss_train.append(error/total)

         # Validation
@@ -164,7 +164,7 @@ def train_variational(model,device,train_data,val_data,config):
                 error += loss.detach().item()
                 total += len(target_vel)

-            logging.info("Validation loss: {:6.3f}".format(error/total))
+            logging.info("Validation variational loss: {:6.3f}".format(error/total))
             list_loss_val.append(error/total)
             if (error/total)<min_val_error:
                 min_val_error = error/total
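The scalar that loss.backward() differentiates is not shown in this hunk; for a variational model it is typically the Gaussian NLL returned by forward() plus a scaled KL term (an ELBO-style objective). A sketch of one plausible combination; the weighting scheme is an assumption, not necessarily what train_variational uses.

```python
def variational_loss(nll_loss, kl_loss, batch_size, kl_weight=1.0):
    """ELBO-style objective: data term plus weighted complexity term.
    Normalizing the KL by the batch size keeps the two terms on a
    comparable per-sample scale; the exact scheme here is assumed."""
    return nll_loss + kl_weight * kl_loss / batch_size
```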