Skip to content

Commit

Permalink
Slightly optimize retrieving set of extras fields (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
CasperWA authored Jul 15, 2021
1 parent 6090339 commit 0105b0a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 45 deletions.
10 changes: 8 additions & 2 deletions aiida_optimade/cli/cmd_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,14 @@ def calc(obj: dict, fields: Tuple[str], force_yes: bool, silent: bool):
" This may take several minutes!"
)

STRUCTURES._filter_fields = {
STRUCTURES.resource_mapper.get_backend_field(_) for _ in fields
STRUCTURES._extras_fields = {
STRUCTURES.resource_mapper.get_backend_field(_)[
len(STRUCTURES.resource_mapper.PROJECT_PREFIX) :
]
for _ in fields
if STRUCTURES.resource_mapper.get_backend_field(_).startswith(
STRUCTURES.resource_mapper.PROJECT_PREFIX
)
}
updated_pks = STRUCTURES._check_and_calculate_entities(cli=not silent)
except click.Abort:
Expand Down
9 changes: 7 additions & 2 deletions aiida_optimade/cli/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,14 @@ def init(obj: dict, force: bool, silent: bool, mongo: bool, filename: str):
}
entries = [[_] for _ in entries]

STRUCTURES._filter_fields = {
STRUCTURES.resource_mapper.get_backend_field(_)
STRUCTURES._extras_fields = {
STRUCTURES.resource_mapper.get_backend_field(_)[
len(STRUCTURES.resource_mapper.PROJECT_PREFIX) :
]
for _ in STRUCTURES.resource_mapper.ALL_ATTRIBUTES
if STRUCTURES.resource_mapper.get_backend_field(_).startswith(
STRUCTURES.resource_mapper.PROJECT_PREFIX
)
}
updated_pks = STRUCTURES._check_and_calculate_entities(
cli=not silent,
Expand Down
80 changes: 40 additions & 40 deletions aiida_optimade/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
# "Cache"
self._data_available: int = None
self._data_returned: int = None
self._filter_fields: Set[str] = None
self._extras_fields: Set[str] = None
self._latest_filter: Dict[str, Any] = None
self._count: Dict[str, Any] = None
self._checked_extras_filter_fields: set = set()
Expand Down Expand Up @@ -105,7 +105,7 @@ def _clear_cache(self) -> None:
"""Clear in-memory attributes cache"""
self._data_available: int = None
self._data_returned: int = None
self._filter_fields: set = None
self._extras_fields: set = None
self._latest_filter: dict = None
self._count: dict = None
self._checked_extras_filter_fields: set = set()
Expand Down Expand Up @@ -183,17 +183,15 @@ def find( # pylint: disable=too-many-branches
single_entry = isinstance(params, SingleEntryQueryParams)
response_fields = criteria.pop("fields", set())

if criteria.get("filters", {}) and self._get_extras_filter_fields():
for requested_extras_field in self._get_extras_filter_fields():
if criteria.get("filters", {}) and self._extras_fields:
for requested_extras_field in self._extras_fields:
if requested_extras_field not in self._checked_extras_filter_fields:
LOGGER.debug(
"Checking all extras fields have been calculated (and possibly "
"calculate them)."
)
self._check_and_calculate_entities()
self._checked_extras_filter_fields |= (
self._get_extras_filter_fields()
)
self._checked_extras_filter_fields |= self._extras_fields
break
else:
LOGGER.debug(
Expand Down Expand Up @@ -363,7 +361,7 @@ def handle_query_params(
# filter
if cursor_kwargs.get("filter", False):
cursor_kwargs["filters"] = cursor_kwargs.pop("filter")
self._set_filter_fields(cursor_kwargs["filters"])
self._find_extras_fields(cursor_kwargs["filters"])
else:
cursor_kwargs.pop("filter", None)

Expand Down Expand Up @@ -425,46 +423,48 @@ def parse_sort_params(self, sort_params: str) -> List[Dict[str, Dict[str, str]]]
)
return sort_spec

def _set_filter_fields(self, filters: Union[dict, list]) -> None:
"""Set all properties to be found in AiiDA Nodes."""
def _find_extras_fields(self, filters: Union[dict, list]) -> None:
"""Collect all properties to be found in AiiDA Node extras.
Parameters:
filters: The complete or part of the parsed and transformed `filter` query
parameter.
"""
from copy import deepcopy

def __filter_fields_util(
def __filter_fields_util( # pylint: disable=unused-private-member
_filters: Union[dict, list]
): # pylint: disable=unused-private-member
) -> Union[dict, list]:
if isinstance(_filters, dict):
res = {}
for key, value in _filters.items():
new_value = value
if isinstance(value, (dict, list)):
new_value = __filter_fields_util(value)
aliased_key = self.resource_mapper.get_backend_field(key)
res[aliased_key] = new_value
self._filter_fields.add(aliased_key)
res[key] = (
__filter_fields_util(value)
if isinstance(value, (dict, list))
else value
)
self._extras_fields |= {
key[len(self.resource_mapper.PROJECT_PREFIX) :]
for key in _filters
if key.startswith(self.resource_mapper.PROJECT_PREFIX)
}
elif isinstance(_filters, list):
res = []
for item in _filters:
new_value = item
if isinstance(item, (dict, list)):
new_value = __filter_fields_util(item)
res.append(new_value)
res = [
__filter_fields_util(item)
if isinstance(item, (dict, list))
else item
for item in _filters
]
else:
raise NotImplementedError(
"_alias_filter can only handle dict and list objects"
"_find_extras_fields can only handle dict and list objects."
)
return res

self._filter_fields = set()
self._extras_fields = set()
__filter_fields_util(deepcopy(filters))

def _get_extras_filter_fields(self) -> set:
"""Get all queried fields saved in Node extras."""
return {
field[len(self.resource_mapper.PROJECT_PREFIX) :] # noqa: E203
for field in self._filter_fields
if field.startswith(self.resource_mapper.PROJECT_PREFIX)
}

def _check_and_calculate_entities(
self, cli: bool = False, entries: List[List[int]] = None
) -> List[int]:
Expand All @@ -491,21 +491,19 @@ def _update_entities(entities: List[List[Any]], fields: List[str]):
for entity in entities:
field_to_entity_value = dict(zip(optimade_fields, entity))
retrieved_attributes = field_to_entity_value.copy()
for missing_attribute in self._get_extras_filter_fields():
for missing_attribute in self._extras_fields:
retrieved_attributes.pop(missing_attribute)
self.resource_mapper.build_attributes(
retrieved_attributes=retrieved_attributes,
entry_pk=field_to_entity_value["id"],
node_type=field_to_entity_value["type"],
missing_attributes=self._get_extras_filter_fields(),
missing_attributes=self._extras_fields,
)

extras_keys = [
key for key in self.resource_mapper.PROJECT_PREFIX.split(".") if key
]
filter_fields = [
{"!has_key": field} for field in self._get_extras_filter_fields()
]
filter_fields = [{"!has_key": field} for field in self._extras_fields]
necessary_entity_ids = (
self._find_all(
filters={
Expand All @@ -528,7 +526,9 @@ def _update_entities(entities: List[List[Any]], fields: List[str]):
# Create the missing OPTIMADE fields
fields = {"id", "type"}
fields |= self.get_attribute_fields()
fields |= {f"_{self.provider_prefix}_" + _ for _ in self.provider_fields}
fields |= {
f"_{self.provider_prefix}_{field}" for field in self.provider_fields
}
fields = list({self.resource_mapper.get_backend_field(_) for _ in fields})

entities = self._find_all(
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ ignore =
# Line to long. Handled by black.
E501
# Line break before binary operator. This is preferred formatting for black.
W503
W503
# Whitespace before ':'
E203

0 comments on commit 0105b0a

Please sign in to comment.