Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented tempfile for TEM #595

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions external_modules/total_energy_module/total_energy.f90
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
SUBROUTINE initialize(file_name, y_planes_in, calculate_eigts_in)
!----------------------------------------------------------------------------
! Derived from Quantum Espresso code
!! author: Paolo Giannozzi
Expand Down Expand Up @@ -29,6 +29,7 @@ SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
LOGICAL, INTENT(IN), OPTIONAL :: calculate_eigts_in
LOGICAL :: calculate_eigts = .false.
INTEGER, INTENT(IN), OPTIONAL :: y_planes_in
CHARACTER(len=256), INTENT(IN) :: file_name
! Parse optional arguments.
IF (PRESENT(calculate_eigts_in)) THEN
calculate_eigts = calculate_eigts_in
Expand All @@ -45,7 +46,7 @@ SUBROUTINE initialize(y_planes_in, calculate_eigts_in)
!
CALL environment_start ( 'PWSCF' )
!
CALL read_input_file ('PW', 'mala.pw.scf.in' )
CALL read_input_file ('PW', file_name )
CALL run_pwscf_setup ( exit_status, calculate_eigts)

print *, "Setup completed"
Expand Down
25 changes: 12 additions & 13 deletions mala/descriptors/bispectrum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Bispectrum descriptor class."""

import os
import tempfile

import ase
import ase.io
Expand Down Expand Up @@ -963,21 +962,21 @@ def __compute_ui(self, nr_atoms, atoms_cutoff, distances_cutoff, grid):
)
jju1 += 1
if jju_outer in self.__index_u1_symmetry_pos:
ulist_r_ij[
:, self.__index_u1_symmetry_pos[jju2]
] = ulist_r_ij[:, self.__index_u_symmetry_pos[jju2]]
ulist_i_ij[
:, self.__index_u1_symmetry_pos[jju2]
] = -ulist_i_ij[:, self.__index_u_symmetry_pos[jju2]]
ulist_r_ij[:, self.__index_u1_symmetry_pos[jju2]] = (
ulist_r_ij[:, self.__index_u_symmetry_pos[jju2]]
)
ulist_i_ij[:, self.__index_u1_symmetry_pos[jju2]] = (
-ulist_i_ij[:, self.__index_u_symmetry_pos[jju2]]
)
jju2 += 1

if jju_outer in self.__index_u1_symmetry_neg:
ulist_r_ij[
:, self.__index_u1_symmetry_neg[jju3]
] = -ulist_r_ij[:, self.__index_u_symmetry_neg[jju3]]
ulist_i_ij[
:, self.__index_u1_symmetry_neg[jju3]
] = ulist_i_ij[:, self.__index_u_symmetry_neg[jju3]]
ulist_r_ij[:, self.__index_u1_symmetry_neg[jju3]] = (
-ulist_r_ij[:, self.__index_u_symmetry_neg[jju3]]
)
ulist_i_ij[:, self.__index_u1_symmetry_neg[jju3]] = (
ulist_i_ij[:, self.__index_u_symmetry_neg[jju3]]
)
jju3 += 1

# This emulates add_uarraytot.
Expand Down
19 changes: 11 additions & 8 deletions mala/descriptors/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,33 +228,36 @@ def setup_lammps_tmp_files(self, lammps_type, outdir):
Type of descriptor calculation (e.g. bgrid for bispectrum)
outdir: str
Directory where lammps files are kept

Returns
-------
None
"""
if get_rank() == 0:
prefix_inp_str = "lammps_" + lammps_type + "_input"
prefix_log_str = "lammps_" + lammps_type + "_log"
lammps_tmp_input_file=tempfile.NamedTemporaryFile(
lammps_tmp_input_file = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix_inp_str, suffix="_.tmp", dir=outdir
)
self.lammps_temporary_input = lammps_tmp_input_file.name
lammps_tmp_input_file.close()

lammps_tmp_log_file=tempfile.NamedTemporaryFile(
lammps_tmp_log_file = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix_log_str, suffix="_.tmp", dir=outdir
)
self.lammps_temporary_log = lammps_tmp_log_file.name
lammps_tmp_log_file.close()
else:
self.lammps_temporary_input=None
self.lammps_temporary_log=None
self.lammps_temporary_input = None
self.lammps_temporary_log = None

if self.parameters._configuration["mpi"]:
self.lammps_temporary_input = get_comm().bcast(self.lammps_temporary_input, root=0)
self.lammps_temporary_log = get_comm().bcast(self.lammps_temporary_log, root=0)

self.lammps_temporary_input = get_comm().bcast(
self.lammps_temporary_input, root=0
)
self.lammps_temporary_log = get_comm().bcast(
self.lammps_temporary_log, root=0
)

# Calculations
##############
Expand Down
27 changes: 11 additions & 16 deletions mala/interfaces/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ase.calculators.calculator import Calculator, all_changes

from mala import Parameters, Network, DataHandler, Predictor, LDOS
from mala.common.parallelizer import barrier, parallel_warn
from mala.common.parallelizer import barrier, parallel_warn, get_rank, get_comm


class MALA(Calculator):
Expand Down Expand Up @@ -154,28 +154,23 @@ def calculate(
# Get the LDOS from the NN.
ldos = self.predictor.predict_for_atoms(atoms)

# forces = np.zeros([len(atoms), 3], dtype=np.float64)

# If an MPI environment is detected, ASE will use it for writing.
# Therefore we have to do this before forking.
self.data_handler.target_calculator.write_tem_input_file(
atoms,
self.data_handler.target_calculator.qe_input_data,
self.data_handler.target_calculator.qe_pseudopotentials,
self.data_handler.target_calculator.grid_dimensions,
self.data_handler.target_calculator.kpoints,
)

# Use the LDOS determined DOS and density to get energy and forces.
ldos_calculator: LDOS = self.data_handler.target_calculator

ldos_calculator.read_from_array(ldos)
self.results["energy"] = ldos_calculator.total_energy
energy, self.last_energy_contributions = (
ldos_calculator.get_total_energy(return_energy_contributions=True)
)
self.last_energy_contributions = (
ldos_calculator._density_calculator.total_energy_contributions.copy()
)
self.last_energy_contributions["e_band"] = ldos_calculator.band_energy
self.last_energy_contributions["e_entropy_contribution"] = (
ldos_calculator.entropy_contribution
)
barrier()

# Use the LDOS determined DOS and density to get energy and forces.
self.results["energy"] = energy
# forces = np.zeros([len(atoms), 3], dtype=np.float64)
# if "forces" in properties:
# self.results["forces"] = forces

Expand Down
25 changes: 22 additions & 3 deletions mala/targets/density.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Electronic density calculation class."""

import os.path
import time

from ase.units import Rydberg, Bohr, m
Expand All @@ -16,6 +17,8 @@
parallel_warn,
barrier,
get_size,
get_comm,
get_rank,
)
from mala.targets.target import Target
from mala.targets.cube_parser import read_cube, write_cube
Expand Down Expand Up @@ -406,7 +409,7 @@ def read_from_cube(self, path, units="1/Bohr^3", **kwargs):
printout("Reading density from .cube file ", path, min_verbosity=0)
# automatically convert units if they are None since cube files take atomic units
if units is None:
units="1/Bohr^3"
units = "1/Bohr^3"
if units != "1/Bohr^3":
printout(
"The expected units for the density from cube files are 1/Bohr^3\n"
Expand Down Expand Up @@ -960,12 +963,14 @@ def __setup_total_energy_module(
else:
kpoints = self.kpoints

self.write_tem_input_file(
tem_input_name = self.write_tem_input_file(
atoms_Angstrom,
qe_input_data,
qe_pseudopotentials,
self.grid_dimensions,
kpoints,
get_comm(),
get_rank(),
)

# initialize the total energy module.
Expand All @@ -984,8 +989,22 @@ def __setup_total_energy_module(
)
barrier()
t0 = time.perf_counter()
te.initialize(self.y_planes)

# We have to make sure we have the correct format for the file.
# QE expects the file without a path, and with a fixed length.
# I chose 256 for this length, simply to have some space in case
# we need it at some point (i.e., the tempfile format changes).
tem_input_name_qe = os.path.basename(tem_input_name)
tem_input_name_qe = tem_input_name_qe + " " * (
256 - len(tem_input_name_qe)
)
te.initialize(tem_input_name_qe, self.y_planes)
barrier()

# Right after setup we can delete the file.
if get_rank() == 0:
os.remove(tem_input_name)

printout(
"Total energy module: Time used by total energy initialization: {:.8f}s".format(
time.perf_counter() - t0
Expand Down
30 changes: 28 additions & 2 deletions mala/targets/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import json
import os
import tempfile

from ase.neighborlist import NeighborList
from ase.units import Rydberg, kB
Expand All @@ -14,7 +15,12 @@
from scipy.integrate import simpson

from mala.common.parameters import Parameters, ParametersTargets
from mala.common.parallelizer import printout, parallel_warn, get_rank
from mala.common.parallelizer import (
printout,
parallel_warn,
get_rank,
get_comm,
)
from mala.targets.calculation_helpers import fermi_function
from mala.common.physical_data import PhysicalData
from mala.descriptors.atomic_density import AtomicDensity
Expand Down Expand Up @@ -1333,6 +1339,8 @@ def write_tem_input_file(
qe_pseudopotentials,
grid_dimensions,
kpoints,
mpi_communicator,
mpi_rank,
):
"""
Write a QE-style input file for the total energy module.
Expand Down Expand Up @@ -1360,6 +1368,14 @@ def write_tem_input_file(

kpoints : dict
k-grid used, usually None or (1,1,1) for TEM calculations.

mpi_communicator : MPI.COMM_WORLD
An MPI comminucator. If no MPI is enabled, this will simply be
None.

mpi_rank : int
Rank within MPI

"""
# Specify grid dimensions, if any are given.
if (
Expand All @@ -1379,14 +1395,24 @@ def write_tem_input_file(
# the DFT calculation. If symmetry is then on in here, that
# leads to errors.
# qe_input_data["nosym"] = False
if mpi_rank == 0:
tem_input_file = tempfile.NamedTemporaryFile(
delete=False, prefix="mala.pw.scf.", suffix=".in", dir="./"
).name
else:
tem_input_file = None

if mpi_communicator is not None:
tem_input_file = mpi_communicator.bcast(tem_input_file, root=0)
ase.io.write(
"mala.pw.scf.in",
tem_input_file,
atoms_Angstrom,
"espresso-in",
input_data=qe_input_data,
pseudopotentials=qe_pseudopotentials,
kpts=kpoints,
)
return tem_input_file

def restrict_data(self, array):
"""
Expand Down