diff --git a/pdebench/data_download/README.md b/pdebench/data_download/README.md index 8e619db..5be2a5d 100644 --- a/pdebench/data_download/README.md +++ b/pdebench/data_download/README.md @@ -57,7 +57,7 @@ python visualize_pdes.py --pde_name "1d_reacdiff" https://darus.uni-stuttgart.de/api/access/datafile/133110 # visualize -python visualize_pdes.py --pde_name "advection" +python visualize_pdes.py --pde_name "advection" --param 0.4 ``` --- @@ -69,7 +69,7 @@ python visualize_pdes.py --pde_name "advection" https://darus.uni-stuttgart.de/api/access/datafile/133136 # visualize -python visualize_pdes.py --pde_name "burgers" +python visualize_pdes.py --pde_name "burgers" --param 0.01 ``` --- diff --git a/pdebench/data_download/visualize_pdes.py b/pdebench/data_download/visualize_pdes.py index 838cc5f..b5fddae 100644 --- a/pdebench/data_download/visualize_pdes.py +++ b/pdebench/data_download/visualize_pdes.py @@ -39,8 +39,9 @@ def visualize_diff_sorp(path, seed=None): num_samples = len(h5_file.keys()) # randomly choose a seed for picking a sample that will subsequently be visualized + rng = np.random.default_rng() if not seed: - seed = np.random.randint(0, num_samples) + seed = rng.integers(low=0, high=num_samples, size=1).item() # Ensure the seed number is defined assert seed < num_samples, "Seed number too high!" @@ -62,7 +63,6 @@ def visualize_diff_sorp(path, seed=None): ) # show an initial one first else: im = ax.plot(data[i].squeeze(), animated=True, color="blue") - ax.plot ims.append([im[0]]) # Animate the plot @@ -86,8 +86,9 @@ def visualize_2d_reacdiff(path, seed=None): num_samples = len(h5_file.keys()) # randomly choose a seed for picking a sample that will subsequently be visualized + rng = np.random.default_rng() if not seed: - seed = np.random.randint(0, num_samples) + seed = rng.integers(low=0, high=num_samples, size=1).item() # Ensure the seed number is defined assert seed < num_samples, "Seed number too high!" @@ -134,8 +135,9 @@ def visualize_swe(path, seed=None): num_samples = len(h5_file.keys()) # randomly choose a seed for picking a sample that will subsequently be visualized + rng = np.random.default_rng() if not seed: - seed = np.random.randint(0, num_samples) + seed = rng.integers(low=0, high=num_samples, size=1).item() # Ensure the seed number is defined assert seed < num_samples, "Seed number too high!" @@ -176,7 +178,7 @@ def visualize_burgers(path, param=None): # Read the h5 file and store the data if param is not None: flnm = "1D_Burgers_Sols_Nu" + str(param) + ".hdf5" - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "1D_Burgers_Sols_Nu0.01.hdf5" @@ -200,7 +202,6 @@ def visualize_burgers(path, param=None): im = ax.plot( xcrd, data[i].squeeze(), animated=True, color="blue" ) # show an initial one first - ax.plot ims.append([im[0]]) # Animate the plot @@ -222,7 +223,7 @@ def visualize_advection(path, param=None): # Read the h5 file and store the data if param is not None: flnm = "1D_Advection_Sols_beta" + str(param) + ".hdf5" - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "1D_Advection_Sols_beta0.4.hdf5" @@ -242,7 +243,6 @@ def visualize_advection(path, param=None): im = ax.plot(xcrd, data[i].squeeze(), animated=True) if i == 0: ax.plot(xcrd, data[i].squeeze()) # show an initial one first - ax.plot ims.append([im[0]]) # Animate the plot @@ -275,7 +275,7 @@ def visualize_1d_cfd(path, param=None): + str(param[3]) + "_Train.hdf5" ) - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "1D_CFD_Rand_Eta1.e-8_Zeta1.e-8_periodic_Train.hdf5" @@ -296,7 +296,6 @@ def visualize_1d_cfd(path, param=None): im = ax.plot(xcrd, dd[i].squeeze(), animated=True) if i == 0: ax.plot(xcrd, dd[i].squeeze()) # show an initial one first - ax.plot ims.append([im[0]]) # Animate the plot @@ -335,7 +334,7 @@ def visualize_2d_cfd(path, param=None): + str(param[5]) + "_Train.hdf5" ) - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "2D_CFD_Rand_M0.1_Eta1e-8_Zeta1e-8_periodic_512_Train.hdf5" @@ -379,7 +378,7 @@ def visualize_3d_cfd(path, param=None): + str(param[4]) + "_Train.hdf5" ) - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "3D_CFD_Rand_M1.0_Eta1e-8_Zeta1e-8_periodic_Train.hdf5" @@ -422,7 +421,7 @@ def visualize_darcy(path, param=None): # Read the h5 file and store the data if param is not None: flnm = "2D_DarcyFlow_beta" + str(param) + "_Train.hdf5" - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "2D_DarcyFlow_beta1.0_Train.hdf5" @@ -458,7 +457,7 @@ def visualize_1d_reacdiff(path, param=None): if param is not None: assert len(param) == 2, "param should include Nu and Rho as list" flnm = "ReacDiff_Nu" + str(param[0]) + "_Rho" + str(param[1]) + ".hdf5" - assert os.path.isfile(path + flnm), "no such file! " + path + flnm + assert Path(path + flnm).is_file(), "no such file! " + path + flnm else: flnm = "ReacDiff_Nu1.0_Rho1.0.hdf5" @@ -478,7 +477,6 @@ def visualize_1d_reacdiff(path, param=None): im = ax.plot(xcrd, data[i].squeeze(), animated=True) if i == 0: ax.plot(xcrd, data[i].squeeze()) # show an initial one first - ax.plot ims.append([im[0]]) # Animate the plot @@ -498,7 +496,7 @@ def visualize_1d_reacdiff(path, param=None): arg_parser.add_argument( "--data_path", type=str, - default=".", + default="./", help="Path to the hdf5 data where the downloaded data reside", ) arg_parser.add_argument( @@ -546,4 +544,5 @@ def visualize_1d_reacdiff(path, param=None): elif args.pde_name == "1d_reacdiff": visualize_1d_reacdiff(args.data_path, args.params) else: - raise ValueError("PDE name not recognized!") + errmsg = "PDE name not recognized!" + raise ValueError(errmsg)