Skip to content

Commit

Permalink
Merge pull request mala-project#595 from RandomDefaultUser/temp_file_qe
Browse files Browse the repository at this point in the history
Implemented `tempfile` for TEM
  • Loading branch information
RandomDefaultUser authored Oct 25, 2024
2 parents 33e319b + db86eb9 commit fd221a6
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 44 deletions.
5 changes: 3 additions & 2 deletions external_modules/total_energy_module/total_energy.f90
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
SUBROUTINE initialize(file_name, y_planes_in, calculate_eigts_in)
!----------------------------------------------------------------------------
! Derived from Quantum Espresso code
!! author: Paolo Giannozzi
Expand Down Expand Up @@ -29,6 +29,7 @@ SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
LOGICAL, INTENT(IN), OPTIONAL :: calculate_eigts_in
LOGICAL :: calculate_eigts = .false.
INTEGER, INTENT(IN), OPTIONAL :: y_planes_in
CHARACTER(len=256), INTENT(IN) :: file_name
! Parse optional arguments.
IF (PRESENT(calculate_eigts_in)) THEN
calculate_eigts = calculate_eigts_in
Expand All @@ -45,7 +46,7 @@ SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
!
CALL environment_start ( 'PWSCF' )
!
CALL read_input_file ('PW', 'mala.pw.scf.in' )
CALL read_input_file ('PW', file_name )
CALL run_pwscf_setup ( exit_status, calculate_eigts)

print *, "Setup completed"
Expand Down
25 changes: 12 additions & 13 deletions mala/descriptors/bispectrum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Bispectrum descriptor class."""

import os
import tempfile

import ase
import ase.io
Expand Down Expand Up @@ -963,21 +962,21 @@ def __compute_ui(self, nr_atoms, atoms_cutoff, distances_cutoff, grid):
)
jju1 += 1
if jju_outer in self.__index_u1_symmetry_pos:
ulist_r_ij[
:, self.__index_u1_symmetry_pos[jju2]
] = ulist_r_ij[:, self.__index_u_symmetry_pos[jju2]]
ulist_i_ij[
:, self.__index_u1_symmetry_pos[jju2]
] = -ulist_i_ij[:, self.__index_u_symmetry_pos[jju2]]
ulist_r_ij[:, self.__index_u1_symmetry_pos[jju2]] = (
ulist_r_ij[:, self.__index_u_symmetry_pos[jju2]]
)
ulist_i_ij[:, self.__index_u1_symmetry_pos[jju2]] = (
-ulist_i_ij[:, self.__index_u_symmetry_pos[jju2]]
)
jju2 += 1

if jju_outer in self.__index_u1_symmetry_neg:
ulist_r_ij[
:, self.__index_u1_symmetry_neg[jju3]
] = -ulist_r_ij[:, self.__index_u_symmetry_neg[jju3]]
ulist_i_ij[
:, self.__index_u1_symmetry_neg[jju3]
] = ulist_i_ij[:, self.__index_u_symmetry_neg[jju3]]
ulist_r_ij[:, self.__index_u1_symmetry_neg[jju3]] = (
-ulist_r_ij[:, self.__index_u_symmetry_neg[jju3]]
)
ulist_i_ij[:, self.__index_u1_symmetry_neg[jju3]] = (
ulist_i_ij[:, self.__index_u_symmetry_neg[jju3]]
)
jju3 += 1

# This emulates add_uarraytot.
Expand Down
19 changes: 11 additions & 8 deletions mala/descriptors/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,33 +228,36 @@ def setup_lammps_tmp_files(self, lammps_type, outdir):
Type of descriptor calculation (e.g. bgrid for bispectrum)
outdir: str
Directory where lammps files are kept
Returns
-------
None
"""
if get_rank() == 0:
prefix_inp_str = "lammps_" + lammps_type + "_input"
prefix_log_str = "lammps_" + lammps_type + "_log"
lammps_tmp_input_file=tempfile.NamedTemporaryFile(
lammps_tmp_input_file = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix_inp_str, suffix="_.tmp", dir=outdir
)
self.lammps_temporary_input = lammps_tmp_input_file.name
lammps_tmp_input_file.close()

lammps_tmp_log_file=tempfile.NamedTemporaryFile(
lammps_tmp_log_file = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix_log_str, suffix="_.tmp", dir=outdir
)
self.lammps_temporary_log = lammps_tmp_log_file.name
lammps_tmp_log_file.close()
else:
self.lammps_temporary_input=None
self.lammps_temporary_log=None
self.lammps_temporary_input = None
self.lammps_temporary_log = None

if self.parameters._configuration["mpi"]:
self.lammps_temporary_input = get_comm().bcast(self.lammps_temporary_input, root=0)
self.lammps_temporary_log = get_comm().bcast(self.lammps_temporary_log, root=0)

self.lammps_temporary_input = get_comm().bcast(
self.lammps_temporary_input, root=0
)
self.lammps_temporary_log = get_comm().bcast(
self.lammps_temporary_log, root=0
)

# Calculations
##############
Expand Down
27 changes: 11 additions & 16 deletions mala/interfaces/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ase.calculators.calculator import Calculator, all_changes

from mala import Parameters, Network, DataHandler, Predictor, LDOS
from mala.common.parallelizer import barrier, parallel_warn
from mala.common.parallelizer import barrier, parallel_warn, get_rank, get_comm


class MALA(Calculator):
Expand Down Expand Up @@ -154,28 +154,23 @@ def calculate(
# Get the LDOS from the NN.
ldos = self.predictor.predict_for_atoms(atoms)

# forces = np.zeros([len(atoms), 3], dtype=np.float64)

# If an MPI environment is detected, ASE will use it for writing.
# Therefore we have to do this before forking.
self.data_handler.target_calculator.write_tem_input_file(
atoms,
self.data_handler.target_calculator.qe_input_data,
self.data_handler.target_calculator.qe_pseudopotentials,
self.data_handler.target_calculator.grid_dimensions,
self.data_handler.target_calculator.kpoints,
)

# Use the LDOS determined DOS and density to get energy and forces.
ldos_calculator: LDOS = self.data_handler.target_calculator

ldos_calculator.read_from_array(ldos)
self.results["energy"] = ldos_calculator.total_energy
energy, self.last_energy_contributions = (
ldos_calculator.get_total_energy(return_energy_contributions=True)
)
self.last_energy_contributions = (
ldos_calculator._density_calculator.total_energy_contributions.copy()
)
self.last_energy_contributions["e_band"] = ldos_calculator.band_energy
self.last_energy_contributions["e_entropy_contribution"] = (
ldos_calculator.entropy_contribution
)
barrier()

# Use the LDOS determined DOS and density to get energy and forces.
self.results["energy"] = energy
# forces = np.zeros([len(atoms), 3], dtype=np.float64)
# if "forces" in properties:
# self.results["forces"] = forces

Expand Down
25 changes: 22 additions & 3 deletions mala/targets/density.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Electronic density calculation class."""

import os.path
import time

from ase.units import Rydberg, Bohr, m
Expand All @@ -16,6 +17,8 @@
parallel_warn,
barrier,
get_size,
get_comm,
get_rank,
)
from mala.targets.target import Target
from mala.targets.cube_parser import read_cube, write_cube
Expand Down Expand Up @@ -406,7 +409,7 @@ def read_from_cube(self, path, units="1/Bohr^3", **kwargs):
printout("Reading density from .cube file ", path, min_verbosity=0)
# automatically convert units if they are None since cube files take atomic units
if units is None:
units="1/Bohr^3"
units = "1/Bohr^3"
if units != "1/Bohr^3":
printout(
"The expected units for the density from cube files are 1/Bohr^3\n"
Expand Down Expand Up @@ -960,12 +963,14 @@ def __setup_total_energy_module(
else:
kpoints = self.kpoints

self.write_tem_input_file(
tem_input_name = self.write_tem_input_file(
atoms_Angstrom,
qe_input_data,
qe_pseudopotentials,
self.grid_dimensions,
kpoints,
get_comm(),
get_rank(),
)

# initialize the total energy module.
Expand All @@ -984,8 +989,22 @@ def __setup_total_energy_module(
)
barrier()
t0 = time.perf_counter()
te.initialize(self.y_planes)

# We have to make sure we have the correct format for the file.
# QE expects the file without a path, and with a fixed length.
# I chose 256 for this length, simply to have some space in case
# we need it at some point (i.e., the tempfile format changes).
tem_input_name_qe = os.path.basename(tem_input_name)
tem_input_name_qe = tem_input_name_qe + " " * (
256 - len(tem_input_name_qe)
)
te.initialize(tem_input_name_qe, self.y_planes)
barrier()

# Right after setup we can delete the file.
if get_rank() == 0:
os.remove(tem_input_name)

printout(
"Total energy module: Time used by total energy initialization: {:.8f}s".format(
time.perf_counter() - t0
Expand Down
30 changes: 28 additions & 2 deletions mala/targets/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import json
import os
import tempfile

from ase.neighborlist import NeighborList
from ase.units import Rydberg, kB
Expand All @@ -14,7 +15,12 @@
from scipy.integrate import simpson

from mala.common.parameters import Parameters, ParametersTargets
from mala.common.parallelizer import printout, parallel_warn, get_rank
from mala.common.parallelizer import (
printout,
parallel_warn,
get_rank,
get_comm,
)
from mala.targets.calculation_helpers import fermi_function
from mala.common.physical_data import PhysicalData
from mala.descriptors.atomic_density import AtomicDensity
Expand Down Expand Up @@ -1333,6 +1339,8 @@ def write_tem_input_file(
qe_pseudopotentials,
grid_dimensions,
kpoints,
mpi_communicator,
mpi_rank,
):
"""
Write a QE-style input file for the total energy module.
Expand Down Expand Up @@ -1360,6 +1368,14 @@ def write_tem_input_file(
kpoints : dict
k-grid used, usually None or (1,1,1) for TEM calculations.
mpi_communicator : MPI.COMM_WORLD
An MPI comminucator. If no MPI is enabled, this will simply be
None.
mpi_rank : int
Rank within MPI
"""
# Specify grid dimensions, if any are given.
if (
Expand All @@ -1379,14 +1395,24 @@ def write_tem_input_file(
# the DFT calculation. If symmetry is then on in here, that
# leads to errors.
# qe_input_data["nosym"] = False
if mpi_rank == 0:
tem_input_file = tempfile.NamedTemporaryFile(
delete=False, prefix="mala.pw.scf.", suffix=".in", dir="./"
).name
else:
tem_input_file = None

if mpi_communicator is not None:
tem_input_file = mpi_communicator.bcast(tem_input_file, root=0)
ase.io.write(
"mala.pw.scf.in",
tem_input_file,
atoms_Angstrom,
"espresso-in",
input_data=qe_input_data,
pseudopotentials=qe_pseudopotentials,
kpts=kpoints,
)
return tem_input_file

def restrict_data(self, array):
"""
Expand Down

0 comments on commit fd221a6

Please sign in to comment.