Skip to content

Commit

Permalink
Feat: pop None inputs specified in overrides
Browse files Browse the repository at this point in the history
Fixes #653

Add the possibility of popping input namespaces by specifying
None in the override for the specific namespace. A decorator
is added that generalize the concept to any implementation of
get_builder_from_protocol.
  • Loading branch information
bastonero committed Mar 29, 2024
1 parent ae7d248 commit 75ab490
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/aiida_quantumespresso/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""Decorators for several purposes."""


def remove_none_overrides(func):
"""Remove namespaces of the returned builder of a `get_builder*` method."""

def recursively_remove_nones(item):
"""Recursively remove keys with None values from dictionaries."""
if isinstance(item, dict):
return {key: recursively_remove_nones(value) for key, value in item.items() if value is not None}
return item

def remove_keys_from_builder(builder, keys, path=()):
"""Recursively remove specified keys from the builder based on a path."""
if not keys:
return
current_level = keys.pop(0)
if hasattr(builder, current_level):
if keys:
next_attr = getattr(builder, current_level)
remove_keys_from_builder(next_attr, keys, path + (current_level,))

delattr(builder, current_level)

def wrapper(*args, **kwargs):
if 'overrides' in kwargs and kwargs['overrides'] is not None:
original_overrides = kwargs['overrides']

# Identify paths to keys with None values to be removed
paths_to_remove = []

def find_paths(item, path=()):
if isinstance(item, dict):
for key, value in item.items():
if value is None:
paths_to_remove.append(path + (key,))

find_paths(value, path + (key,))

find_paths(original_overrides)

# Recursively remove keys with None values from overrides
cleaned_overrides = recursively_remove_nones(original_overrides)
kwargs['overrides'] = cleaned_overrides

# Call the original function to get the builder
builder = func(*args, **kwargs)

# Remove specified keys from the builder
for path in paths_to_remove:
remove_keys_from_builder(builder, list(path))

return builder

return func(*args, **kwargs)

return wrapper
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiida.engine import ToContext, WorkChain, if_

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.utils.decorators import remove_none_overrides
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
Expand Down Expand Up @@ -120,6 +121,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'bands.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
Expand Down
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
from aiida_quantumespresso.utils.decorators import remove_none_overrides
from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults

from ..protocols.utils import ProtocolMixin
Expand Down Expand Up @@ -103,6 +104,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'base.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(
cls,
code,
Expand Down
26 changes: 26 additions & 0 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure):
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'relax': {'base_final_scf': None}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member

overrides = {'relax': None}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'relax' not in builder # pylint: disable=no-member

overrides = {'relax': {'base': {'pw': {'parameters': {'SYSTEM': {'ecutwfc': None}}}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member

overrides = {'relax': {'base': {'pw': {'parameters': None}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member
11 changes: 11 additions & 0 deletions tests/workflows/protocols/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure):

assert metadata['options']['queue_name'] == queue_name
assert metadata['options']['withmpi'] == withmpi


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'kpoints_distance': None}
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'kpoints_distance' not in builder # pylint: disable=no-member

0 comments on commit 75ab490

Please sign in to comment.