Improve aggregation and status performance (no conditions). (#818)
* Use cached_statepoint.

* Require signac 2.2.0 for cached_statepoint.

* Cache the list of job ids while buffered.

This allows faster `job in project` tests and iteration over jobs.

Also remove some expensive `open_job` calls and `job in project` checks that are
not needed while registering aggregates.

* Do not iterate over all jobs for labels when there are no labels defined.

This saves a small amount of absolute time in projects with no labels. It also gives
the *appearance* of faster status checks, as the user sees only one progress bar.

Also, hide the "labels" section of the status output when there are no labels to show.

* Run pre-commit.

* Suggest cached_statepoint usage in pre conditions (see the sketch below).

* Update change log.

* Fix typo.

---------

Co-authored-by: Corwin Kerr <[email protected]>
joaander and cbkerr authored Feb 15, 2024
1 parent 85b3b9b commit a5648af
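The precondition suggestion in the commit message can be illustrated with a minimal sketch. The project class, operation name, statepoint key, and threshold below are hypothetical examples, not part of this commit:

```python
from flow import FlowProject


class Project(FlowProject):
    pass


# job.cached_statepoint reads from signac's project-wide statepoint cache,
# so evaluating this condition does not open each job's statepoint file.
@Project.pre(lambda job: job.cached_statepoint["temperature"] > 0)
@Project.operation
def simulate(job):
    print(f"simulating {job.id}")
```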
Showing 6 changed files with 69 additions and 34 deletions.
2 changes: 2 additions & 0 deletions changelog.txt
@@ -25,6 +25,8 @@ Changed
 +++++++

 - Move "Submit command" comment to end of pretend output (#805).
+- Improve aggregate registration performance (#818).
+- Hide empty "Labels" section in status output when there are no labels (#818).

 Removed
 +++++++
31 changes: 11 additions & 20 deletions flow/aggregates.py
@@ -235,19 +235,19 @@ def foo(*jobs):
             if default is None:

                 def keyfunction(job):
-                    return job.statepoint[key]
+                    return job.cached_statepoint[key]

             else:

                 def keyfunction(job):
-                    return job.statepoint.get(key, default)
+                    return job.cached_statepoint.get(key, default)

         elif isinstance(key, Iterable):
             keys = list(key)
             if default is None:

                 def keyfunction(job):
-                    return [job.statepoint[key] for key in keys]
+                    return [job.cached_statepoint[key] for key in keys]

             else:
                 if isinstance(default, Iterable):
@@ -264,7 +264,7 @@ def keyfunction(job):

                 def keyfunction(job):
                     return [
-                        job.statepoint.get(key, default_value)
+                        job.cached_statepoint.get(key, default_value)
                         for key, default_value in zip(keys, default)
                     ]

@@ -430,11 +430,6 @@ def _register_aggregates(self):
         # Initialize the internal mapping from id to aggregate
         self._aggregates_by_id = {}
         for aggregate in self._generate_aggregates():
-            for job in aggregate:
-                if job not in self._project:
-                    raise LookupError(
-                        f"The signac job {job.id} not found in {self._project}"
-                    )
             try:
                 stored_aggregate = tuple(aggregate)
             except TypeError:  # aggregate is not iterable
@@ -456,7 +451,7 @@ def _generate_aggregates(self):
         else:

             def sort_function(job):
-                return job.statepoint[self._aggregator._sort_by]
+                return job.cached_statepoint[self._aggregator._sort_by]

         jobs = sorted(
             jobs,
@@ -517,14 +512,7 @@ def __contains__(self, id):
             The job id.

         """
-        try:
-            self._project.open_job(id=id)
-        except KeyError:
-            return False
-        except LookupError:
-            raise
-        else:
-            return True
+        return self._project._contains_job_id(job_id=id)

     def __len__(self):
         return len(self._project)
@@ -538,8 +526,11 @@ def __hash__(self):
         return hash(self._project_repr)

     def keys(self):
-        for job in self._project:
-            yield job.id
+        if self._project._is_buffered:
+            return self._project._jobs_cursor._ids
+        else:
+            for job in self._project:
+                yield job.id

     def values(self):
         for job in self._project:
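The `keys()` change above returns the cursor's precomputed list of ids while the project is buffered. A simplified standalone sketch of why a cached id collection makes membership tests and iteration cheap (an illustrative model, not flow's actual classes):

```python
# Illustrative model of a buffered id cache (hypothetical class, not flow's
# aggregate store): a precomputed set gives O(1) membership tests, while the
# unbuffered path must consult the project on disk for every check.
class BufferedIdView:
    def __init__(self, job_ids):
        self._ids = list(job_ids)  # preserves iteration order
        self._id_set = set(job_ids)  # constant-time membership tests

    def __contains__(self, job_id):
        return job_id in self._id_set

    def __iter__(self):
        return iter(self._ids)


view = BufferedIdView(["0d32543f785d3459f27b8746f2053824"])
print("0d32543f785d3459f27b8746f2053824" in view)  # True, without disk access
```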
60 changes: 50 additions & 10 deletions flow/project.py
@@ -1273,6 +1273,11 @@ def hi_all(*jobs):
             are used by :meth:`~.detect_operation_graph` when comparing
             conditions for equality. The tag defaults to the bytecode of the
             function.
+
+        .. tip::
+
+            Use ``job.cached_statepoint`` for the best performance in preconditions
+            that depend on the job's statepoint.
         """

         _parent_class = parent_class
@@ -1746,6 +1751,9 @@ def __init__(self, path=None, environment=None, entrypoint=None):
             format_checker=jsonschema.Draft7Validator.FORMAT_CHECKER,
         )

+        self._is_buffered = False
+        self._jobs_cursor = None
+
         # Associate this class with a compute environment.
         self._environment = environment or get_environment()
@@ -1777,6 +1785,27 @@ def __init__(self, path=None, environment=None, entrypoint=None):
         self._group_to_aggregate_store = _bidict()
         self._register_groups()

+    def __iter__(self):
+        """Provide a cached view of jobs while in a buffered state."""
+        if self._is_buffered:
+            return iter(self._jobs_cursor)
+        else:
+            return super().__iter__()
+
+    def __len__(self):
+        """Provide a cached view of jobs while in a buffered state."""
+        if self._is_buffered:
+            return len(self._jobs_cursor._ids)
+        else:
+            return super().__len__()
+
+    def _contains_job_id(self, job_id):
+        """Provide a cached view of jobs while in a buffered state."""
+        if self._is_buffered:
+            return job_id in self._jobs_cursor._id_set
+        else:
+            return super()._contains_job_id(job_id)
+
     def _setup_template_environment(self):
         """Set up the jinja2 template environment.
@@ -2762,14 +2791,17 @@ def compute_status(data):
             self._get_job_labels,
             ignore_errors=ignore_errors,
         )
-        job_labels = list(
-            parallel_executor(
-                compute_labels,
-                individual_jobs,
-                desc="Fetching labels",
-                file=err,
+        if len(self._label_functions) > 0:
+            job_labels = list(
+                parallel_executor(
+                    compute_labels,
+                    individual_jobs,
+                    desc="Fetching labels",
+                    file=err,
+                )
             )
-        )
+        else:
+            job_labels = []

         def combine_group_and_operation_status(aggregate_status_results):
             group_statuses = {}
@@ -3113,10 +3145,10 @@ def display_group_name(group_name):
                 {
                     key
                     for job in individual_jobs
-                    for key in job.statepoint.keys()
+                    for key in job.cached_statepoint.keys()
                     if len(
                         {
-                            _to_hashable(job.statepoint().get(key))
+                            _to_hashable(job.cached_statepoint.get(key))
                             for job in individual_jobs
                         }
                     )
@@ -3156,7 +3188,7 @@ def dotted_get(mapping, key):
                 else:
                     parameter_name = parameter
                 if statepoint is None:
-                    statepoint = job.statepoint()
+                    statepoint = job.cached_statepoint
                 status["parameters"][parameter] = shorten(
                     str(self._alias(dotted_get(statepoint, parameter_name))),
                     param_max_width,
@@ -3981,10 +4013,18 @@ def _convert_jobs_to_aggregates(self, jobs):
     def _buffered(self):
         """Enable the use of buffered mode for certain functions."""
         logger.debug("Entering buffered mode.")
+
+        self._jobs_cursor = self.find_jobs()
+        self._is_buffered = True
+
         with signac.buffered():
             yield
+
         logger.debug("Exiting buffered mode.")
+
+        self._is_buffered = False
+        self._jobs_cursor = None

     def _generate_submit_script(
         self, _id, operations, template, show_template_help, **kwargs
     ):
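The `_buffered` context manager above computes a jobs cursor once and reuses it for the duration of the block. A minimal standalone sketch of the same pattern (a hypothetical class; the `try`/`finally` is an extra safety measure this sketch adds, not something shown in the diff):

```python
import contextlib


class Cache:
    """Toy stand-in for a project that can cache an expensive query."""

    def __init__(self):
        self._is_buffered = False
        self._cursor = None

    def _find_items(self):
        return ["a", "b", "c"]  # stand-in for an expensive lookup

    @contextlib.contextmanager
    def buffered(self):
        # Compute the expensive view once and flag the buffered state.
        self._cursor = self._find_items()
        self._is_buffered = True
        try:
            yield
        finally:
            # Drop the cache even if the block raises.
            self._is_buffered = False
            self._cursor = None


cache = Cache()
with cache.buffered():
    print(cache._cursor)  # consumers reuse the cached view here
```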
6 changes: 4 additions & 2 deletions flow/templates/base_status.jinja
@@ -18,11 +18,13 @@
 Overview: {{ total_num_jobs_or_aggregates }} jobs/aggregates, {{ total_num_eligible_jobs_or_aggregates }} jobs/aggregates with eligible operations.

 {% block progress %}
+{% if progress_sorted|length > 0 %}
 | label | ratio |
 | ----- | ----- |
-{% for label in progress_sorted %}
+  {% for label in progress_sorted %}
 | {{ label[0] }} | {{ label[1]|draw_progress_bar(total_num_job_labels, '\\') }} |
-{% endfor %}
+  {% endfor %}
+{% endif %}
 {% endblock progress %}

 {% block operation_summary %}
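The template guard above hides the label table when `progress_sorted` is empty. A small standalone jinja2 sketch of the same idea (a simplified template, not the real base_status.jinja):

```python
import jinja2

template = jinja2.Template(
    "{% if labels|length > 0 %}"
    "| label | ratio |\n"
    "| ----- | ----- |\n"
    "{% for name, ratio in labels %}"
    "| {{ name }} | {{ ratio }} |\n"
    "{% endfor %}"
    "{% endif %}"
)

print(template.render(labels=[]))  # renders nothing: the section is hidden
print(template.render(labels=[("done", "50%")]))  # renders the table
```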
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
 ]
 dependencies = [
     # The core package.
-    "signac>=2.0.0",
+    "signac>=2.2.0",
     # For the templated generation of (submission) scripts.
     "jinja2>=3.0.0",
     # To enable the parallelized execution of operations across processes.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-signac>=2.0.0
+signac>=2.2.0
 jinja2>=3.0.0
 cloudpickle>=1.6.0
 deprecation>=2.0.0
