diff --git a/pdebench/models/metrics.py b/pdebench/models/metrics.py index ab16704..5f87098 100644 --- a/pdebench/models/metrics.py +++ b/pdebench/models/metrics.py @@ -157,21 +157,24 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHigh=12): +def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHigh=12, initial_step=1): """ code for calculate metrics discussed in the Brain-storming session RMSE, normalized RMSE, max error, RMSE at the boundaries, conserved variables, RMSE in Fourier space, temporal sensitivity """ pred, target = pred.to(device), target.to(device) # (batch, nx^i..., timesteps, nc) + # slice out `initial context` timesteps + pred = pred[..., initial_step:, :] + target = target[..., initial_step:, :] idxs = target.size() - if len(idxs) == 4: + if len(idxs) == 4: # 1D pred = pred.permute(0, 3, 1, 2) target = target.permute(0, 3, 1, 2) - if len(idxs) == 5: + if len(idxs) == 5: # 2D pred = pred.permute(0, 4, 1, 2, 3) target = target.permute(0, 4, 1, 2, 3) - elif len(idxs) == 6: + elif len(idxs) == 6: # 3D pred = pred.permute(0, 5, 1, 2, 3, 4) target = target.permute(0, 5, 1, 2, 3, 4) idxs = target.size() @@ -238,12 +241,12 @@ def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHig err_BD_z = (pred[:, :, :, :, 0] - target[:, :, :, :, 0]) ** 2 err_BD_z += (pred[:, :, :, :, -1] - target[:, :, :, :, -1]) ** 2 err_BD = ( - torch.sum(err_BD_x.view([nb, -1, nt]), dim=-2) - + torch.sum(err_BD_y.view([nb, -1, nt]), dim=-2) - + torch.sum(err_BD_z.view([nb, -1, nt]), dim=-2) + torch.sum(err_BD_x.contiguous().view([nb, -1, nt]), dim=-2) + + torch.sum(err_BD_y.contiguous().view([nb, -1, nt]), dim=-2) + + torch.sum(err_BD_z.contiguous().view([nb, -1, nt]), dim=-2) ) err_BD = err_BD / (2 * nx * ny + 2 * ny * nz + 2 * nz * nx) - err_BD = torch.mean(torch.sqrt(err_BD), dim=0) + err_BD = torch.sqrt(err_BD) if len(idxs) == 4: # 1D nx = idxs[2] @@ -350,7 +353,7 @@ def metrics( _err_Max, _err_BD, _err_F, - ) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz) + ) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz, initial_step=initial_step) if itot == 0: err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F = ( @@ -408,7 +411,7 @@ def metrics( _err_Max, _err_BD, _err_F, - ) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz) + ) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz, initial_step=initial_step) if itot == 0: err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F = ( _err_RMSE,