Improved schema with new workflow base sections
Deleted unused methods

Improved docstrings
JosePizarro3 committed Oct 10, 2024
1 parent 20e0277 commit 447535d
Showing 4 changed files with 61 additions and 99 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ maintainers = [
]
license = { file = "LICENSE" }
dependencies = [
"nomad-lab>=1.3.0",
"nomad-lab@file:///home/josepizarro/nomad",
"matid>=2.0.0.dev2",
]

29 changes: 13 additions & 16 deletions src/nomad_simulations/schema_packages/workflow/base_workflows.py
@@ -1,33 +1,31 @@
from functools import wraps
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.metainfo.workflow import TaskReference, Workflow
from nomad.datamodel.metainfo.workflow_new import BaseTask
from nomad.datamodel.metainfo.workflow_new import Workflow2 as Workflow
from nomad.metainfo import SubSection

from nomad_simulations.schema_packages.model_method import BaseModelMethod
from nomad_simulations.schema_packages.outputs import Outputs


def check_n_tasks(n_tasks: Optional[int] = None):
def check_n_tasks(n_tasks: int = 1):
"""
Check if the `tasks` of a workflow exist. If the `n_tasks` input specified, it checks whether `tasks`
is of the same length as `n_tasks`.
Check that the `tasks` of a workflow exist and that their length matches `n_tasks`.
Args:
n_tasks (Optional[int], optional): The length of the `tasks` needs to be checked if set to an integer. Defaults to None.
n_tasks (int): The expected length of `tasks`. Defaults to 1.
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.tasks:
return None
if n_tasks is not None and len(self.tasks) != n_tasks:
if not self.tasks or len(self.tasks) != n_tasks:
return None

return func(self, *args, **kwargs)
@@ -39,14 +37,14 @@ def wrapper(self, *args, **kwargs):

class SimulationWorkflow(Workflow):
"""
A base section used to define the workflows of a simulation with references to specific `tasks`, `inputs`, and `outputs`. The
A base section used to define the workflows of a simulation with specific `tasks`, `inputs`, and `outputs`. The
normalize function checks the definition of these sections and sets the name of the workflow.
A `SimulationWorkflow` will be composed of:
- a `method` section containing methodological parameters used specifically during the workflow,
- a list of `inputs` with references to the `ModelSystem` and, optionally, `ModelMethod` input sections,
- a list of `outputs` with references to the `Outputs` section,
- a list of `tasks` containing references to the activity `Simulation` used in the workflow,
- a list of `tasks` containing references to, or the section information of, each task used in the workflow,
"""

method = SubSection(
@@ -66,7 +64,7 @@ class BeyondDFTMethod(ArchiveSection):
"""
An abstract section used to store references to the `ModelMethod` sections of each of the
archives defining the `tasks` and used to build the standard `BeyondDFT` workflow. This section needs to be
inherit and the method references need to be defined for each specific case (see, e.g., dft_plus_tb.py module).
inherited, and the method references need to be defined for each specific case (see, e.g., the `dft_plus_tb.py` module).
"""

pass
@@ -104,16 +102,15 @@ def resolve_all_outputs(self) -> list[Outputs]:
all_outputs.append(task.outputs[-1])
return all_outputs

@check_n_tasks()
def resolve_method_refs(
self, tasks: list[TaskReference], tasks_names: list[str]
self, tasks: list[BaseTask], tasks_names: list[str]
) -> list[BaseModelMethod]:
"""
Resolve the references to the `BaseModelMethod` sections in the list of `tasks`. This is useful
when defining the `method` section of the `BeyondDFT` workflow.
Args:
tasks (list[TaskReference]): The list of tasks from which resolve the `BaseModelMethod` sections.
tasks (list[BaseTask]): The list of tasks from which to resolve the `BaseModelMethod` sections.
tasks_names (list[str]): The list of names for each of the tasks forming the BeyondDFT workflow.
Returns:
@@ -132,7 +129,7 @@ def resolve_method_refs(
if not task.m_xpath('task.inputs'):
continue

# Resolve the method of each task.inputs
# Resolve the method of each `tasks[*].task.inputs`
for input in task.task.inputs:
if isinstance(input.section, BaseModelMethod):
method_refs.append(input.section)
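As an aside (not part of this commit): a minimal sketch of how the reworked `check_n_tasks` decorator is meant to be applied, assuming it and `SimulationWorkflow` are imported from this module; the subclass and method names below are illustrative only.

from nomad_simulations.schema_packages.workflow.base_workflows import (
    SimulationWorkflow,
    check_n_tasks,
)


class TwoTaskWorkflow(SimulationWorkflow):
    @check_n_tasks(n_tasks=2)
    def resolve_first_task(self):
        # Only runs when `self.tasks` is populated with exactly two entries;
        # otherwise the decorator short-circuits and returns None.
        return self.tasks[0]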
99 changes: 30 additions & 69 deletions src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py
@@ -4,8 +4,9 @@
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.metainfo.workflow import Link, TaskReference
from nomad.metainfo import Quantity, Reference
from nomad.datamodel.metainfo.workflow_new import LinkReference
from nomad.metainfo import Quantity
from nomad.utils import extract_section

from nomad_simulations.schema_packages.model_method import DFT, TB
from nomad_simulations.schema_packages.workflow import BeyondDFT, BeyondDFTMethod
@@ -21,13 +22,13 @@ class DFTPlusTBMethod(BeyondDFTMethod):
"""

dft_method_ref = Quantity(
type=Reference(DFT),
type=DFT,
description="""
Reference to the DFT `ModelMethod` section in the DFT task.
""",
)
tb_method_ref = Quantity(
type=Reference(TB),
type=TB,
description="""
Reference to the TB `ModelMethod` section in the TB task.
""",
@@ -40,12 +41,10 @@ class DFTPlusTB(BeyondDFT):
two tasks: the initial DFT calculation + the final TB projection.
The section only needs to be populated with (everything else is handled by the `normalize` function):
i. The `tasks` as `TaskReference` sections, adding `task` to the specific archive.workflow2 sections.
ii. The `inputs` and `outputs` as `Link` sections pointing to the specific archives.
i. The `tasks` as `TaskReference` sections, adding `task` to the specific `archive.workflow2` sections.
Note 1: the `inputs[0]` of the `DFTPlusTB` coincides with the `inputs[0]` of the DFT task (`ModelSystem` section).
Note 2: the `outputs[-1]` of the `DFTPlusTB` coincides with the `outputs[-1]` of the TB task (`Outputs` section).
Note 3: the `outputs[-1]` of the DFT task is used as `inputs[0]` of the TB task.
The archive.workflow2 section is:
- name = 'DFT+TB'
@@ -54,68 +53,33 @@
tb_method_ref=tb_archive.data.model_method[-1],
)
- inputs = [
Link(name='Input Model System', section=dft_archive.data.model_system[0]),
LinkReference(name='Input Model System', section=dft_archive.data.model_system[0]),
]
- outputs = [
Link(name='Output TB Data', section=tb_archive.data.outputs[-1]),
LinkReference(name='Output TB Data', section=tb_archive.data.outputs[-1]),
]
- tasks = [
TaskReference(
name='DFT SinglePoint Task',
task=dft_archive.workflow2
inputs=[
Link(name='Input Model System', section=dft_archive.data.model_system[0]),
],
outputs=[
Link(name='Output DFT Data', section=dft_archive.data.outputs[-1]),
]
),
TaskReference(
name='TB SinglePoint Task',
task=tb_archive.workflow2,
inputs=[
Link(name='Output DFT Data', section=dft_archive.data.outputs[-1]),
],
outputs=[
Link(name='Output tb Data', section=tb_archive.data.outputs[-1]),
]
),
TaskReference(task=dft_archive.workflow2),
TaskReference(task=tb_archive.workflow2),
]
"""

@check_n_tasks(n_tasks=2)
def link_task_inputs_outputs(
self, tasks: list[TaskReference], logger: 'BoundLogger'
) -> None:
if not self.inputs or not self.outputs:
logger.warning(
'The `DFTPlusTB` workflow needs to have `inputs` and `outputs` defined in order to link with the `tasks`.'
)
def resolve_inputs_outputs(self) -> None:
"""
Resolve the `inputs` and `outputs` of the `DFTPlusTB` workflow.
"""
input = extract_section(self.tasks[0], ['task', 'inputs[0]', 'section'])
if not input:
return None
self.inputs = [LinkReference(name='Input Model System', section=input)]

dft_task = tasks[0]
tb_task = tasks[1]

# Initial check
if not dft_task.m_xpath('task.outputs'):
output = extract_section(self.tasks[1], ['task', 'outputs[-1]', 'section'])
if not output:
return None

# Input of DFT Task is the ModelSystem
dft_task.inputs = [
Link(name='Input Model System', section=self.inputs[0]),
]
# Output of DFT Task is the output section of the DFT entry
dft_task.outputs = [
Link(name='Output DFT Data', section=dft_task.task.outputs[-1]),
]
# Input of TB Task is the output of the DFT task
tb_task.inputs = [
Link(name='Output DFT Data', section=dft_task.task.outputs[-1]),
]
# Output of TB Task is the output section of the TB entry
tb_task.outputs = [
Link(name='Output TB Data', section=self.outputs[-1]),
]
self.outputs = [LinkReference(name='Output TB Data', section=output)]

# TODO check whether to implement overwriting the FermiLevel.value in the TB entry from the DFT entry
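As an aside on the `extract_section` helper that replaces the manual linking code: a hypothetical sketch of its path syntax as used in `resolve_inputs_outputs` above, where `workflow` stands for an already-populated `DFTPlusTB` instance; the behavior described in the comments is inferred from this usage, not from documentation.

from nomad.utils import extract_section

# Follows workflow.tasks[0].task.inputs[0].section, i.e. the ModelSystem input of
# the DFT task; None is returned as soon as any step of the path is missing.
dft_input_system = extract_section(workflow.tasks[0], ['task', 'inputs[0]', 'section'])

# Follows workflow.tasks[1].task.outputs[-1].section, i.e. the last Outputs of the
# TB task, which becomes the workflow-level output.
tb_output = extract_section(workflow.tasks[1], ['task', 'outputs[-1]', 'section'])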

@@ -144,14 +108,11 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
tasks=self.tasks,
tasks_names=['DFT SinglePoint Task', 'TB SinglePoint Task'],
)
if method_refs is not None:
method_workflow = DFTPlusTBMethod()
for method in method_refs:
if isinstance(method, DFT):
method_workflow.dft_method_ref = method
elif isinstance(method, TB):
method_workflow.tb_method_ref = method
self.method = method_workflow

# Resolve `tasks[*].inputs` and `tasks[*].outputs`
self.link_task_inputs_outputs(tasks=self.tasks, logger=logger)
if method_refs is not None and len(method_refs) == 2:
self.method = DFTPlusTBMethod(
dft_method_ref=method_refs[0], tb_method_ref=method_refs[1]
)

# Resolve `inputs` and `outputs` from the `tasks`
self.resolve_inputs_outputs()
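To make the intended usage concrete, a minimal, hypothetical sketch of wiring two single-point entries into a `DFTPlusTB` workflow and letting `normalize` fill in `method`, `inputs`, and `outputs`. It assumes `DFTPlusTB` is re-exported from `nomad_simulations.schema_packages.workflow`, that `TaskReference` lives next to `LinkReference` in `nomad.datamodel.metainfo.workflow_new` (the docstring keeps using it, but its import is not shown in this diff), and that `dft_archive`, `tb_archive`, `archive`, and `logger` come from the surrounding parsing context.

from nomad.datamodel.metainfo.workflow_new import TaskReference  # assumed location
from nomad_simulations.schema_packages.workflow import DFTPlusTB  # assumed re-export

# Point each task at the single-point workflow of its source entry; `normalize`
# resolves the DFT/TB method references and the workflow-level inputs/outputs.
archive.workflow2 = DFTPlusTB(
    tasks=[
        TaskReference(name='DFT SinglePoint Task', task=dft_archive.workflow2),
        TaskReference(name='TB SinglePoint Task', task=tb_archive.workflow2),
    ],
)
archive.workflow2.normalize(archive=archive, logger=logger)
# Expected after normalization (per the docstring above): inputs[0] links the DFT
# entry's ModelSystem and outputs[-1] links the TB entry's last Outputs.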
30 changes: 17 additions & 13 deletions src/nomad_simulations/schema_packages/workflow/single_point.py
@@ -6,11 +6,11 @@
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.metainfo.workflow import Link
from nomad.datamodel.metainfo.workflow_new import LinkReference
from nomad.metainfo import Quantity
from nomad.utils import extract_section

from nomad_simulations.schema_packages.outputs import SCFOutputs
from nomad_simulations.schema_packages.utils import extract_all_simulation_subsections
from nomad_simulations.schema_packages.workflow import SimulationWorkflow


@@ -26,11 +26,11 @@ class SinglePoint(SimulationWorkflow):
The archive.workflow2 section is:
- name = 'SinglePoint'
- inputs = [
Link(name='Input Model System', section=archive.data.model_system[0]),
Link(name='Input Model Method', section=archive.data.model_method[-1]),
LinkReference(name='Input Model System', section=archive.data.model_system[0]),
LinkReference(name='Input Model Method', section=archive.data.model_method[-1]),
]
- outputs = [
Link(name='Output Data', section=archive.data.outputs[-1]),
LinkReference(name='Output Data', section=archive.data.outputs[-1]),
]
- tasks = []
"""
@@ -53,19 +53,23 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
self.name = 'SinglePoint'

# Define `inputs` and `outputs`
input_model_system, input_model_method, output = (
extract_all_simulation_subsections(archive=archive)
)
if not input_model_system or not input_model_method or not output:
input_model_system = extract_section(archive, ['data', 'model_system'])
output = extract_section(archive, ['data', 'outputs'])
if not input_model_system or not output:
logger.warning(
'Could not find the ModelSystem, ModelMethod, or Outputs section in the archive.data section of the SinglePoint entry.'
'Could not find the `ModelSystem` or `Outputs` section in the archive.data section of the SinglePoint entry.'
)
return
self.inputs = [
Link(name='Input Model System', section=input_model_system),
Link(name='Input Model Method', section=input_model_method),
LinkReference(name='Input Model System', section=input_model_system),
]
self.outputs = [Link(name='Output Data', section=output)]
self.outputs = [LinkReference(name='Output Data', section=output)]
# `ModelMethod` is optional when defining workflows like the `SinglePoint`
input_model_method = extract_section(archive, ['data', 'model_method'])
if input_model_method is not None:
self.inputs.append(
LinkReference(name='Input Model Method', section=input_model_method)
)

# Resolve the `n_scf_steps` if the output is of `SCFOutputs` type
if isinstance(output, SCFOutputs):
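For completeness, a minimal, hypothetical sketch of exercising the updated `SinglePoint.normalize`, assuming `SinglePoint` is importable from `nomad_simulations.schema_packages.workflow` (as in the other modules of this diff) and that `archive` and `logger` come from the surrounding context, with `archive.data` already holding a `ModelSystem`, optionally a `ModelMethod`, and an `Outputs`.

from nomad_simulations.schema_packages.workflow import SinglePoint

archive.workflow2 = SinglePoint()
archive.workflow2.normalize(archive=archive, logger=logger)
# Per the normalize above: inputs[0] links the ModelSystem, outputs[0] links the
# Outputs, and a ModelMethod link is appended to inputs only when one is present;
# n_scf_steps is resolved when the output is of SCFOutputs type.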
