Skip to content

Commit

Permalink
comply with ruff linter
Browse files Browse the repository at this point in the history
  • Loading branch information
kmario23 committed Oct 20, 2024
1 parent 6b23c01 commit be93f68
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions pdebench/data_download/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ python visualize_pdes.py --pde_name "1d_reacdiff"
https://darus.uni-stuttgart.de/api/access/datafile/133110
# visualize
python visualize_pdes.py --pde_name "advection"
python visualize_pdes.py --pde_name "advection" --param 0.4
```

---
Expand All @@ -69,7 +69,7 @@ python visualize_pdes.py --pde_name "advection"
https://darus.uni-stuttgart.de/api/access/datafile/133136
# visualize
python visualize_pdes.py --pde_name "burgers"
python visualize_pdes.py --pde_name "burgers" --param 0.01
```

---
Expand Down
33 changes: 16 additions & 17 deletions pdebench/data_download/visualize_pdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def visualize_diff_sorp(path, seed=None):
num_samples = len(h5_file.keys())

# randomly choose a seed for picking a sample that will subsequently be visualized
rng = np.random.default_rng()
if not seed:
seed = np.random.randint(0, num_samples)
seed = rng.integers(low=0, high=num_samples, size=1).item()

# Ensure the seed number is defined
assert seed < num_samples, "Seed number too high!"
Expand All @@ -62,7 +63,6 @@ def visualize_diff_sorp(path, seed=None):
) # show an initial one first
else:
im = ax.plot(data[i].squeeze(), animated=True, color="blue")
ax.plot
ims.append([im[0]])

# Animate the plot
Expand All @@ -86,8 +86,9 @@ def visualize_2d_reacdiff(path, seed=None):
num_samples = len(h5_file.keys())

# randomly choose a seed for picking a sample that will subsequently be visualized
rng = np.random.default_rng()
if not seed:
seed = np.random.randint(0, num_samples)
seed = rng.integers(low=0, high=num_samples, size=1).item()

# Ensure the seed number is defined
assert seed < num_samples, "Seed number too high!"
Expand Down Expand Up @@ -134,8 +135,9 @@ def visualize_swe(path, seed=None):
num_samples = len(h5_file.keys())

# randomly choose a seed for picking a sample that will subsequently be visualized
rng = np.random.default_rng()
if not seed:
seed = np.random.randint(0, num_samples)
seed = rng.integers(low=0, high=num_samples, size=1).item()

# Ensure the seed number is defined
assert seed < num_samples, "Seed number too high!"
Expand Down Expand Up @@ -176,7 +178,7 @@ def visualize_burgers(path, param=None):
# Read the h5 file and store the data
if param is not None:
flnm = "1D_Burgers_Sols_Nu" + str(param) + ".hdf5"
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "1D_Burgers_Sols_Nu0.01.hdf5"

Expand All @@ -200,7 +202,6 @@ def visualize_burgers(path, param=None):
im = ax.plot(
xcrd, data[i].squeeze(), animated=True, color="blue"
) # show an initial one first
ax.plot
ims.append([im[0]])

# Animate the plot
Expand All @@ -222,7 +223,7 @@ def visualize_advection(path, param=None):
# Read the h5 file and store the data
if param is not None:
flnm = "1D_Advection_Sols_beta" + str(param) + ".hdf5"
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "1D_Advection_Sols_beta0.4.hdf5"

Expand All @@ -242,7 +243,6 @@ def visualize_advection(path, param=None):
im = ax.plot(xcrd, data[i].squeeze(), animated=True)
if i == 0:
ax.plot(xcrd, data[i].squeeze()) # show an initial one first
ax.plot
ims.append([im[0]])

# Animate the plot
Expand Down Expand Up @@ -275,7 +275,7 @@ def visualize_1d_cfd(path, param=None):
+ str(param[3])
+ "_Train.hdf5"
)
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "1D_CFD_Rand_Eta1.e-8_Zeta1.e-8_periodic_Train.hdf5"

Expand All @@ -296,7 +296,6 @@ def visualize_1d_cfd(path, param=None):
im = ax.plot(xcrd, dd[i].squeeze(), animated=True)
if i == 0:
ax.plot(xcrd, dd[i].squeeze()) # show an initial one first
ax.plot
ims.append([im[0]])

# Animate the plot
Expand Down Expand Up @@ -335,7 +334,7 @@ def visualize_2d_cfd(path, param=None):
+ str(param[5])
+ "_Train.hdf5"
)
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "2D_CFD_Rand_M0.1_Eta1e-8_Zeta1e-8_periodic_512_Train.hdf5"

Expand Down Expand Up @@ -379,7 +378,7 @@ def visualize_3d_cfd(path, param=None):
+ str(param[4])
+ "_Train.hdf5"
)
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "3D_CFD_Rand_M1.0_Eta1e-8_Zeta1e-8_periodic_Train.hdf5"

Expand Down Expand Up @@ -422,7 +421,7 @@ def visualize_darcy(path, param=None):
# Read the h5 file and store the data
if param is not None:
flnm = "2D_DarcyFlow_beta" + str(param) + "_Train.hdf5"
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "2D_DarcyFlow_beta1.0_Train.hdf5"

Expand Down Expand Up @@ -458,7 +457,7 @@ def visualize_1d_reacdiff(path, param=None):
if param is not None:
assert len(param) == 2, "param should include Nu and Rho as list"
flnm = "ReacDiff_Nu" + str(param[0]) + "_Rho" + str(param[1]) + ".hdf5"
assert os.path.isfile(path + flnm), "no such file! " + path + flnm
assert Path(path + flnm).is_file(), "no such file! " + path + flnm
else:
flnm = "ReacDiff_Nu1.0_Rho1.0.hdf5"

Expand All @@ -478,7 +477,6 @@ def visualize_1d_reacdiff(path, param=None):
im = ax.plot(xcrd, data[i].squeeze(), animated=True)
if i == 0:
ax.plot(xcrd, data[i].squeeze()) # show an initial one first
ax.plot
ims.append([im[0]])

# Animate the plot
Expand All @@ -498,7 +496,7 @@ def visualize_1d_reacdiff(path, param=None):
arg_parser.add_argument(
"--data_path",
type=str,
default=".",
default="./",
help="Path to the hdf5 data where the downloaded data reside",
)
arg_parser.add_argument(
Expand Down Expand Up @@ -546,4 +544,5 @@ def visualize_1d_reacdiff(path, param=None):
elif args.pde_name == "1d_reacdiff":
visualize_1d_reacdiff(args.data_path, args.params)
else:
raise ValueError("PDE name not recognized!")
errmsg = "PDE name not recognized!"
raise ValueError(errmsg)

0 comments on commit be93f68

Please sign in to comment.