Skip to content

Commit

Permalink
fix metrics addressing #41 and more
Browse files Browse the repository at this point in the history
  • Loading branch information
kmario23 committed Oct 21, 2024
1 parent 1a83119 commit 9f3ca2b
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions pdebench/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,24 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHigh=12):
def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHigh=12, initial_step=1):
"""
code for calculate metrics discussed in the Brain-storming session
RMSE, normalized RMSE, max error, RMSE at the boundaries, conserved variables, RMSE in Fourier space, temporal sensitivity
"""
pred, target = pred.to(device), target.to(device)
# (batch, nx^i..., timesteps, nc)
# slice out `initial context` timesteps
pred = pred[..., initial_step:, :]
target = target[..., initial_step:, :]
idxs = target.size()
if len(idxs) == 4:
if len(idxs) == 4: # 1D
pred = pred.permute(0, 3, 1, 2)
target = target.permute(0, 3, 1, 2)
if len(idxs) == 5:
if len(idxs) == 5: # 2D
pred = pred.permute(0, 4, 1, 2, 3)
target = target.permute(0, 4, 1, 2, 3)
elif len(idxs) == 6:
elif len(idxs) == 6: # 3D
pred = pred.permute(0, 5, 1, 2, 3, 4)
target = target.permute(0, 5, 1, 2, 3, 4)
idxs = target.size()
Expand Down Expand Up @@ -238,12 +241,12 @@ def metric_func(pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHig
err_BD_z = (pred[:, :, :, :, 0] - target[:, :, :, :, 0]) ** 2
err_BD_z += (pred[:, :, :, :, -1] - target[:, :, :, :, -1]) ** 2
err_BD = (
torch.sum(err_BD_x.view([nb, -1, nt]), dim=-2)
+ torch.sum(err_BD_y.view([nb, -1, nt]), dim=-2)
+ torch.sum(err_BD_z.view([nb, -1, nt]), dim=-2)
torch.sum(err_BD_x.contiguous().view([nb, -1, nt]), dim=-2)
+ torch.sum(err_BD_y.contiguous().view([nb, -1, nt]), dim=-2)
+ torch.sum(err_BD_z.contiguous().view([nb, -1, nt]), dim=-2)
)
err_BD = err_BD / (2 * nx * ny + 2 * ny * nz + 2 * nz * nx)
err_BD = torch.mean(torch.sqrt(err_BD), dim=0)
err_BD = torch.sqrt(err_BD)

if len(idxs) == 4: # 1D
nx = idxs[2]
Expand Down Expand Up @@ -350,7 +353,7 @@ def metrics(
_err_Max,
_err_BD,
_err_F,
) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz)
) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz, initial_step=initial_step)

if itot == 0:
err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F = (
Expand Down Expand Up @@ -408,7 +411,7 @@ def metrics(
_err_Max,
_err_BD,
_err_F,
) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz)
) = metric_func(pred, yy, if_mean=True, Lx=Lx, Ly=Ly, Lz=Lz, initial_step=initial_step)
if itot == 0:
err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F = (
_err_RMSE,
Expand Down

0 comments on commit 9f3ca2b

Please sign in to comment.