diff --git a/setup.cfg b/setup.cfg
index 99d372d..7d5b991 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,3 +19,10 @@ match=^[Tt]est
 
 [bdist_wheel]
 universal = 1
+
+[black]
+line-length = 88
+
+[isort]
+line_length = 88
+profile = black
diff --git a/tests/fixtures/linter/linter-invalid-regex.yml b/tests/fixtures/linter/linter-invalid-regex.yml
index 78cf912..1f205e6 100644
--- a/tests/fixtures/linter/linter-invalid-regex.yml
+++ b/tests/fixtures/linter/linter-invalid-regex.yml
@@ -6,7 +6,7 @@ tools:
     cores: 2
     params:
       native_spec: "--mem {mem} --cores {cores} --gpus {gpus}"
-  bwa[0-9]++:
+  bwa[0-9]++\:
     gpus: 2
 
 destinations:
diff --git a/tests/fixtures/mapping-basic.yml b/tests/fixtures/mapping-basic.yml
index f06869d..724f0ab 100644
--- a/tests/fixtures/mapping-basic.yml
+++ b/tests/fixtures/mapping-basic.yml
@@ -28,12 +28,6 @@ tools:
     scheduling:
       require:
         - non_existent
-  invalidly_tagged_tool:
-    scheduling:
-      require:
-        - general
-      reject:
-        - general
   regex_tool.*:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-destinations.yml b/tests/fixtures/mapping-destinations.yml
index 818a95c..e4f92fa 100644
--- a/tests/fixtures/mapping-destinations.yml
+++ b/tests/fixtures/mapping-destinations.yml
@@ -193,7 +193,7 @@ destinations:
       TEST_ENTITY_PRIORITY: "{cores*2}"
     params:
       memory_requests: "{mem*2}"
-      k8s_walltime_limit: 20
+      k8s_walltime_limit: "20"
     rules:
       - if: input_size > 20
         execute: |
diff --git a/tests/fixtures/mapping-invalid-regex.yml b/tests/fixtures/mapping-invalid-regex.yml
index ce3de75..ffc43d0 100644
--- a/tests/fixtures/mapping-invalid-regex.yml
+++ b/tests/fixtures/mapping-invalid-regex.yml
@@ -28,12 +28,6 @@ tools:
     scheduling:
       require:
         - non_existent
-  invalidly_tagged_tool:
-    scheduling:
-      require:
-        - general
-      reject:
-        - general
   regex_tool.*:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-invalid-tags.yml b/tests/fixtures/mapping-invalid-tags.yml
new file mode 100644
index 0000000..1adca66
--- /dev/null
+++ b/tests/fixtures/mapping-invalid-tags.yml
@@ -0,0 +1,40 @@
+global:
+  default_inherits: default
+
+tools:
+  default:
+    abstract: true
+    cores: 2
+    mem: 8
+    env: {}
+    scheduling:
+      require: []
+      prefer:
+        - general
+      accept:
+      reject:
+        - pulsar
+    rules: []
+  invalidly_tagged_tool:
+    scheduling:
+      require:
+        - general
+      reject:
+        - general
+
+destinations:
+  local:
+    runner: local
+    max_accepted_cores: 4
+    max_accepted_mem: 16
+    scheduling:
+      prefer:
+        - general
+  k8s_environment:
+    runner: k8s
+    max_accepted_cores: 16
+    max_accepted_mem: 64
+    max_accepted_gpus: 2
+    scheduling:
+      prefer:
+        - pulsar
diff --git a/tests/fixtures/mapping-rank.yml b/tests/fixtures/mapping-rank.yml
index 192facd..0c45741 100644
--- a/tests/fixtures/mapping-rank.yml
+++ b/tests/fixtures/mapping-rank.yml
@@ -46,12 +46,6 @@ users:
     scheduling:
       require:
         - earth
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   .*@vortex.org:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-role.yml b/tests/fixtures/mapping-role.yml
index bca2058..2640d40 100644
--- a/tests/fixtures/mapping-role.yml
+++ b/tests/fixtures/mapping-role.yml
@@ -68,12 +68,6 @@ users:
     scheduling:
       require:
         - earth
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   .*@vortex.org:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-rules-changed.yml b/tests/fixtures/mapping-rules-changed.yml
index 6860f81..8861cd6 100644
--- a/tests/fixtures/mapping-rules-changed.yml
+++ b/tests/fixtures/mapping-rules-changed.yml
@@ -52,12 +52,6 @@ users:
         max_mem: cores * 6
      - if: input_size >= 5
        fail: Just because
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   .*@vortex.org:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-rules.yml b/tests/fixtures/mapping-rules.yml
index be5e53e..118f0ef 100644
--- a/tests/fixtures/mapping-rules.yml
+++ b/tests/fixtures/mapping-rules.yml
@@ -60,12 +60,6 @@ users:
         max_mem: cores * 6
       - if: input_size >= 5
        fail: Just because
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   .*@vortex.org:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-syntax-error.yml b/tests/fixtures/mapping-syntax-error.yml
index bbd3944..dfe972f 100644
--- a/tests/fixtures/mapping-syntax-error.yml
+++ b/tests/fixtures/mapping-syntax-error.yml
@@ -42,12 +42,6 @@ users:
     scheduling:
       require:
         - earth
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   .*@vortex.org:
     scheduling:
       require:
diff --git a/tests/fixtures/mapping-user-invalid-tags.yml b/tests/fixtures/mapping-user-invalid-tags.yml
new file mode 100644
index 0000000..aa880c2
--- /dev/null
+++ b/tests/fixtures/mapping-user-invalid-tags.yml
@@ -0,0 +1,50 @@
+global:
+  default_inherits: default
+
+tools:
+  default:
+    cores: 2
+    mem: 8
+    gpus: 1
+    env: {}
+    scheduling:
+      require: []
+      prefer:
+        - general
+      accept:
+      reject:
+        - pulsar
+    params:
+      native_spec: "--mem {mem} --cores {cores}"
+    rules: []
+
+users:
+  default:
+    max_cores: 3
+    max_mem: 4
+    env: {}
+    scheduling:
+      require: []
+      prefer:
+        - general
+      accept:
+      reject:
+        - pulsar
+    rules: []
+  improbable@vortex.org:
+    scheduling:
+      require:
+        - pulsar
+      reject:
+        - pulsar
+
+destinations:
+  local:
+    runner: local
+    max_accepted_cores: 4
+    max_accepted_mem: 16
+    scheduling:
+      prefer:
+        - general
+      accept:
+        - pulsar
diff --git a/tests/fixtures/mapping-user.yml b/tests/fixtures/mapping-user.yml
index 3a49bc4..daaa206 100644
--- a/tests/fixtures/mapping-user.yml
+++ b/tests/fixtures/mapping-user.yml
@@ -60,12 +60,6 @@ users:
         - earth
       reject:
         - pulsar
-  improbable@vortex.org:
-    scheduling:
-      require:
-        - pulsar
-      reject:
-        - pulsar
   prefect@vortex.org:
     max_cores: 4
     max_mem: 32
diff --git a/tests/fixtures/scenario-too-many-highmem-jobs.yml b/tests/fixtures/scenario-too-many-highmem-jobs.yml
index ea3e297..f512c85 100644
--- a/tests/fixtures/scenario-too-many-highmem-jobs.yml
+++ b/tests/fixtures/scenario-too-many-highmem-jobs.yml
@@ -63,8 +63,7 @@ users:
             rule_helper = RuleHelper(app)
             # Find all destinations that support highmem
             destinations = [d.dest_name for d in mapper.destinations.values()
-                            if any(d.tpv_dest_tags.filter(tag_value='highmem',
-                                   tag_type=[TagType.REQUIRE, TagType.PREFER, TagType.ACCEPT]))]
+                            if 'highmem' in (d.tpv_dest_tags.require + d.tpv_dest_tags.prefer + d.tpv_dest_tags.accept)]
             count = rule_helper.job_count(for_user_email=user.email, for_destinations=destinations)
             if count > 4:
                 retval = True
diff --git a/tests/test_entity.py b/tests/test_entity.py
index 03621d2..171f96a 100644
--- a/tests/test_entity.py
+++ b/tests/test_entity.py
@@ -37,21 +37,20 @@ def test_all_entities_refer_to_same_loader(self):
         }
         # make sure we are still referring to the same loader after evaluation
         evaluated_entity = gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities(context, tool, user)
-        assert evaluated_entity.loader == original_loader
+        assert evaluated_entity.evaluator == original_loader
         for rule in evaluated_entity.rules:
-            assert rule.loader == original_loader
+            assert rule.evaluator == original_loader
 
     def test_destination_to_dict(self):
-
         tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
         loader = TPVConfigLoader.from_url_or_path(tpv_config)
         # create a destination
-        destination = loader.destinations["k8s_environment"]
+        destination = loader.config.destinations["k8s_environment"]
         # serialize the destination
-        serialized_destination = destination.to_dict()
+        serialized_destination = destination.dict()
         # deserialize the same destination
-        deserialized_destination = Destination.from_dict(loader, serialized_destination)
+        deserialized_destination = Destination(evaluator=loader, **serialized_destination)
         # make sure the deserialized destination is the same as the original
         self.assertEqual(deserialized_destination, destination)
 
@@ -60,24 +59,23 @@ def test_tool_to_dict(self):
         loader = TPVConfigLoader.from_url_or_path(tpv_config)
 
         # create a tool
-        tool = loader.tools["limbo"]
+        tool = loader.config.tools["limbo"]
 
         # serialize the tool
-        serialized_destination = tool.to_dict()
+        serialized_tool = tool.dict()
 
         # deserialize the same tool
-        deserialized_destination = Tool.from_dict(loader, serialized_destination)
+        deserialized_tool = Tool(evaluator=loader, **serialized_tool)
 
         # make sure the deserialized tool is the same as the original
-        self.assertEqual(deserialized_destination, tool)
+        self.assertEqual(deserialized_tool, tool)
 
     def test_tag_equivalence(self):
-        tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE)
-        tag2 = Tag("tag_name2", "tag_value", TagType.REQUIRE)
-        tag3 = Tag("tag_name", "tag_value1", TagType.REQUIRE)
-        tag4 = Tag("tag_name", "tag_value1", TagType.PREFER)
-        same_as_tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE)
+        tag1 = Tag(value="tag_value", tag_type=TagType.REQUIRE)
+        tag2 = Tag(value="tag_value", tag_type=TagType.REQUIRE)
+        tag3 = Tag(value="tag_value1", tag_type=TagType.REQUIRE)
+        tag4 = Tag(value="tag_value1", tag_type=TagType.PREFER)
+        same_as_tag1 = Tag(value="tag_value", tag_type=TagType.REQUIRE)
 
         self.assertEqual(tag1, tag1)
         self.assertEqual(tag1, same_as_tag1)
-        self.assertNotEqual(tag1, tag2)
         self.assertNotEqual(tag1, tag3)
         self.assertNotEqual(tag1, tag4)
         self.assertNotEqual(tag1, "hello")
diff --git a/tests/test_mapper_basic.py b/tests/test_mapper_basic.py
index abe59ea..a70e079 100644
--- a/tests/test_mapper_basic.py
+++ b/tests/test_mapper_basic.py
@@ -35,8 +35,9 @@ def test_map_unschedulable_tool(self):
 
     def test_map_invalidly_tagged_tool(self):
         tool = mock_galaxy.Tool('invalidly_tagged_tool')
-        with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"):
-            self._map_to_destination(tool)
+        config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-tags.yml')
+        with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"):
+            self._map_to_destination(tool, tpv_config_path=config)
 
     def test_map_tool_by_regex(self):
         tool = mock_galaxy.Tool('regex_tool_test')
diff --git a/tests/test_mapper_user.py b/tests/test_mapper_user.py
index 647a0d1..729f47f 100644
--- a/tests/test_mapper_user.py
+++ b/tests/test_mapper_user.py
@@ -8,10 +8,10 @@ class TestMapperUser(unittest.TestCase):
 
     @staticmethod
-    def _map_to_destination(tool, user):
+    def _map_to_destination(tool, user, tpv_config_path=None):
         galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
         job = mock_galaxy.Job()
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user.yml')
+        tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user.yml')
         gateway.ACTIVE_DESTINATION_MAPPER = None
         return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])
 
@@ -40,8 +40,9 @@ def test_map_invalidly_tagged_user(self):
         tool = mock_galaxy.Tool('bwa')
         user = mock_galaxy.User('infinitely', 'improbable@vortex.org')
 
-        with self.assertRaises(IncompatibleTagsException):
-            self._map_to_destination(tool, user)
+        config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user-invalid-tags.yml')
+        with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'pulsar' in \['require', 'reject'\]"):
+            self._map_to_destination(tool, user, tpv_config_path=config)
 
     def test_map_user_by_regex(self):
         tool = mock_galaxy.Tool('bwa')
diff --git a/tox.ini b/tox.ini
index 0883ee8..04c5cad 100644
--- a/tox.ini
+++ b/tox.ini
@@ -6,7 +6,7 @@
 # running the tests.
 
 [tox]
-envlist = py3.10,lint
+envlist = py3.11,lint
 
 [testenv]
 commands = # see setup.cfg for options sent to nosetests and coverage
diff --git a/tpv/__init__.py b/tpv/__init__.py
index 76a6c5c..3abd85a 100644
--- a/tpv/__init__.py
+++ b/tpv/__init__.py
@@ -1,7 +1,7 @@
 """Total Perspective Vortex library setup."""
 
 # Current version of the library
-__version__ = "2.4.0"
+__version__ = "3.0.0"
 
 
 def get_version():
diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py
index 08e232e..dfc4315 100644
--- a/tpv/commands/linter.py
+++ b/tpv/commands/linter.py
@@ -23,8 +23,13 @@ def lint(self):
         except Exception as e:
             log.error(f"Linting failed due to syntax errors in yaml file: {e}")
             raise TPVLintError("Linting failed due to syntax errors in yaml file: ") from e
-        default_inherits = loader.global_settings.get('default_inherits')
-        for tool_regex, tool in loader.tools.items():
+        self.lint_tools(loader)
+        self.lint_destinations(loader)
+        self.print_errors_and_warnings()
+
+    def lint_tools(self, loader):
+        default_inherits = loader.config.global_config.default_inherits
+        for tool_regex, tool in loader.config.tools.items():
             try:
                 re.compile(tool_regex)
             except re.error:
@@ -34,7 +39,10 @@ def lint(self):
                 f"The tool named: {default_inherits} is marked globally as the tool to inherit from "
                 "by default. You may want to mark it as abstract if it is not an actual tool and it "
                 "will be excluded from scheduling decisions.")
-        for destination in loader.destinations.values():
+
+    def lint_destinations(self, loader):
+        default_inherits = loader.config.global_config.default_inherits
+        for destination in loader.config.destinations.values():
             if not destination.runner and not destination.abstract:
                 self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. "
                                    "The runner parameter is mandatory.")
@@ -51,6 +59,8 @@ def lint(self):
                 f"The destination named: {default_inherits} is marked globally as the destination to inherit from "
                 "by default. You may want to mark it as abstract if it is not meant to be dispatched to, and it "
                 "will be excluded from scheduling decisions.")
+
+    def print_errors_and_warnings(self):
         if self.warnings:
             for w in self.warnings:
                 log.warning(w)
diff --git a/tpv/core/entities.py b/tpv/core/entities.py
index 1e214a5..652273e 100644
--- a/tpv/core/entities.py
+++ b/tpv/core/entities.py
@@ -1,362 +1,361 @@
-from __future__ import annotations
-
-from enum import Enum
-import logging
 import copy
+import itertools
+import logging
+from collections import defaultdict
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Annotated, Any, ClassVar, Dict, Iterable, List, Optional
 
 from galaxy import util as galaxy_util
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    model_validator,
+)
+from pydantic.json_schema import SkipJsonSchema
+
+from .evaluator import TPVCodeEvaluator
 
 log = logging.getLogger(__name__)
 
 
-class TagType(Enum):
-    REQUIRE = 2
-    PREFER = 1
-    ACCEPT = 0
-    REJECT = -1
+class TryNextDestinationOrFail(Exception):
+    # Try next destination, fail job if destination options exhausted
+    pass
 
-    def __int__(self):
-        return self.value
 
+class TryNextDestinationOrWait(Exception):
+    # Try next destination, raise JobNotReadyException if destination options exhausted
+    pass
 
-class Tag:
-    def __init__(self, name, value, tag_type: Enum):
-        self.name = name
-        self.value = value
-        self.tag_type = tag_type
 
+@dataclass
+class TPVFieldMetadata:
+    complex_property: bool = False
 
-    def __eq__(self, other):
-        if not isinstance(other, Tag):
-            # don't attempt to compare against unrelated types
-            return NotImplemented
-        return self.name == other.name and self.value == other.value and self.tag_type == other.tag_type
 
+def default_field_copier(entity1, entity2, property_name):
+    # if property_name in entity1.model_fields_set
+    return (
+        getattr(
+            entity1,
+            property_name,
+        )
+        if getattr(entity1, property_name, None) is not None
+        else getattr(entity2, property_name, None)
+    )
 
-    def __repr__(self):
-        return f"<Tag: name={self.name}, value={self.value}, type={self.tag_type}>"
 
+def default_dict_copier(entity1, entity2, property_name):
+    new_dict = copy.deepcopy(getattr(entity2, property_name)) or {}
+    new_dict.update(copy.deepcopy(getattr(entity1, property_name)) or {})
+    return new_dict
 
-class IncompatibleTagsException(Exception):
-    def __init__(self, first_set, second_set):
+
+class TagType(IntEnum):
+    REQUIRE = 2
+    PREFER = 1
+    ACCEPT = 0
+    REJECT = -1
+
+
+@dataclass(frozen=True)
+class Tag:
+    value: str
+    tag_type: TagType
+
+
+class IncompatibleTagsException(Exception):
+    def __init__(self, first_set: "SchedulingTags", second_set: "SchedulingTags"):
         super().__init__(
-            f"Cannot combine tag sets because require and reject tags mismatch. First tag set requires:"
-            f" {[tag.value for tag in first_set.filter(TagType.REQUIRE)]} and rejects:"
-            f" {[tag.value for tag in first_set.filter(TagType.REJECT)]}. Second tag set requires:"
-            f" {[tag.value for tag in second_set.filter(TagType.REQUIRE)]} and rejects:"
-            f" {[tag.value for tag in second_set.filter(TagType.REJECT)]}.")
+            "Cannot combine tag sets because require and reject tags mismatch. First"
+            f" tag set requires: {first_set.require} and rejects: {first_set.reject}."
+            f" Second tag set requires: {second_set.require} and rejects:"
+            f" {second_set.reject}."
+        )
 
-class TryNextDestinationOrFail(Exception):
-    # Try next destination, fail job if destination options exhausted
-    pass
 
+class SchedulingTags(BaseModel):
+    require: Optional[List[str]] = Field(default_factory=list)
+    prefer: Optional[List[str]] = Field(default_factory=list)
+    accept: Optional[List[str]] = Field(default_factory=list)
+    reject: Optional[List[str]] = Field(default_factory=list)
 
+    @model_validator(mode="after")
+    def check_duplicates(self):
+        tag_occurrences = defaultdict(list)
 
-class TryNextDestinationOrWait(Exception):
-    # Try next destination, raise JobNotReadyException if destination options exhausted
-    pass
+        # Track tag occurrences within each category and across categories
+        for tag_type in TagType:
+            field = tag_type.name.lower()
+            tags = getattr(self, field, []) or []
+            for tag in tags:
+                tag_occurrences[tag].append(field)
+
+        # Identify duplicates
+        duplicates = {
+            tag: fields for tag, fields in tag_occurrences.items() if len(fields) > 1
+        }
 
+        # Build the detailed error message
+        if duplicates:
+            details = "; ".join(
+                [f"'{tag}' in {fields}" for tag, fields in duplicates.items()]
+            )
+            raise ValueError(f"Duplicate tags found: {details}")
 
-class TagSetManager(object):
+        return self
 
-    def __init__(self, tags=[]):
-        self.tags = tags or []
+    @property
+    def tags(self) -> Iterable[Tag]:
+        return itertools.chain(
+            (Tag(value=tag, tag_type=TagType.REQUIRE) for tag in self.require or []),
+            (Tag(value=tag, tag_type=TagType.PREFER) for tag in self.prefer or []),
+            (Tag(value=tag, tag_type=TagType.ACCEPT) for tag in self.accept or []),
+            (Tag(value=tag, tag_type=TagType.REJECT) for tag in self.reject or []),
+        )
 
-    def add_tag_override(self, tag: Tag):
-        # pop the tag if it exists, as a tag can only belong to one type
-        self.tags = list(filter(lambda t: t.value != tag.value, self.tags))
-        self.tags.append(tag)
+    def all_tag_values(self) -> Iterable[str]:
+        return itertools.chain(
+            self.require or [], self.prefer or [], self.accept or [], self.reject or []
+        )
 
-    def filter(self, tag_type: TagType | list[TagType] = None,
-               tag_name: str = None, tag_value: str = None) -> list[Tag]:
+    def filter(
+        self, tag_type: TagType | list[TagType] = None, tag_value: str = None
+    ) -> list[Tag]:
         filtered = self.tags
         if tag_type:
             if isinstance(tag_type, TagType):
                 filtered = (tag for tag in filtered if tag.tag_type == tag_type)
             else:
                 filtered = (tag for tag in filtered if tag.tag_type in tag_type)
-        if tag_name:
-            filtered = (tag for tag in filtered if tag.name == tag_name)
         if tag_value:
             filtered = (tag for tag in filtered if tag.value == tag_value)
         return filtered
 
-    def add_tag_overrides(self, tags: list[Tag]):
-        for tag in tags:
-            self.add_tag_override(tag)
-
-    def can_combine(self, other: TagSetManager) -> bool:
-        self_required = ((t.name, t.value) for t in self.filter(TagType.REQUIRE))
-        other_required = ((t.name, t.value) for t in other.filter(TagType.REQUIRE))
-        self_rejected = ((t.name, t.value) for t in self.filter(TagType.REJECT))
-        other_rejected = ((t.name, t.value) for t in other.filter(TagType.REJECT))
-        if set(self_required).intersection(set(other_rejected)):
-            return False
-        elif set(self_rejected).intersection(set(other_required)):
+    def add_tag_override(self, tag_type: TagType, tag_value: str):
+        # Remove tag from all categories
+        for field in TagType:
+            field_name = field.name.lower()
+            if tag_value in (getattr(self, field_name) or []):
+                getattr(self, field_name).remove(tag_value)
+
+        # Add tag to the specified category
+        tag_field = tag_type.name.lower()
+        current_tags = getattr(self, tag_field, []) or []
+        setattr(self, tag_field, current_tags + [tag_value])
+
+    def inherit(self, other: "SchedulingTags") -> "SchedulingTags":
+        # Create new lists of tags that combine self and other
+        new_tags = copy.deepcopy(other)
+        for tag_type in [
+            TagType.ACCEPT,
+            TagType.PREFER,
+            TagType.REQUIRE,
+            TagType.REJECT,
+        ]:
+            for tag in getattr(self, tag_type.name.lower()) or []:
+                new_tags.add_tag_override(tag_type, tag)
+        return new_tags
+
+    def can_combine(self, other: "SchedulingTags") -> bool:
+        self_required = set(self.require or [])
+        other_required = set(other.require or [])
+        self_rejected = set(self.reject or [])
+        other_rejected = set(other.reject or [])
+
+        if self_required.intersection(other_rejected) or self_rejected.intersection(
+            other_required
+        ):
             return False
-        else:
-            return True
+        return True
 
-    def inherit(self, other) -> TagSetManager:
-        assert type(self) is type(other)
-        new_tag_set = TagSetManager()
-        new_tag_set.add_tag_overrides(other.filter(TagType.ACCEPT))
-        new_tag_set.add_tag_overrides(other.filter(TagType.PREFER))
-        new_tag_set.add_tag_overrides(other.filter(TagType.REQUIRE))
-        new_tag_set.add_tag_overrides(other.filter(TagType.REJECT))
-        new_tag_set.add_tag_overrides(self.filter(TagType.ACCEPT))
-        new_tag_set.add_tag_overrides(self.filter(TagType.PREFER))
-        new_tag_set.add_tag_overrides(self.filter(TagType.REQUIRE))
-        new_tag_set.add_tag_overrides(self.filter(TagType.REJECT))
-        return new_tag_set
-
-    def combine(self, other: TagSetManager) -> TagSetManager:
+    def combine(self, other: "SchedulingTags") -> "SchedulingTags":
         if not self.can_combine(other):
             raise IncompatibleTagsException(self, other)
-        new_tag_set = TagSetManager()
-        # Add accept tags first, as they should be overridden by prefer, require and reject tags
-        new_tag_set.add_tag_overrides(other.filter(TagType.ACCEPT))
-        new_tag_set.add_tag_overrides(self.filter(TagType.ACCEPT))
-        # Next add preferred, as they should be overridden by require and reject tags
-        new_tag_set.add_tag_overrides(other.filter(TagType.PREFER))
-        new_tag_set.add_tag_overrides(self.filter(TagType.PREFER))
-        # Require and reject tags can be added in either order, as there's no overlap
-        new_tag_set.add_tag_overrides(other.filter(TagType.REQUIRE))
-        new_tag_set.add_tag_overrides(self.filter(TagType.REQUIRE))
-        new_tag_set.add_tag_overrides(other.filter(TagType.REJECT))
-        new_tag_set.add_tag_overrides(self.filter(TagType.REJECT))
-        return new_tag_set
-
-    def match(self, other: TagSetManager) -> bool:
-        return (all(other.contains_tag(required) for required in self.filter(TagType.REQUIRE)) and
-                all(self.contains_tag(required) for required in other.filter(TagType.REQUIRE)) and
-                not any(other.contains_tag(rejected) for rejected in self.filter(TagType.REJECT)) and
-                not any(self.contains_tag(rejected) for rejected in other.filter(TagType.REJECT)))
-
-    def contains_tag(self, tag) -> bool:
-        """
-        Returns true if the name and value of the tag match. Ignores tag_type.
-        :param tag:
-        :return:
-        """
-        return any(self.filter(tag_name=tag.name, tag_value=tag.value))
+        new_tags = SchedulingTags()
 
-    def score(self, other: TagSetManager) -> bool:
-        """
-        Computes a compatibility score between tag sets.
-        :param other:
-        :return:
-        """
-        return (sum(int(tag.tag_type) * int(o.tag_type) for tag in self.tags for o in other.tags
-                    if tag.name == o.name and tag.value == o.value)
-                # penalize tags that don't exist in the other
-                - sum(int(tag.tag_type) for tag in self.tags if not other.contains_tag(tag)))
+        # Add tags in the specific precedence order
+        for tag_type in [
+            TagType.ACCEPT,
+            TagType.PREFER,
+            TagType.REQUIRE,
+            TagType.REJECT,
+        ]:
+            for tag in getattr(other, tag_type.name.lower()) or []:
+                new_tags.add_tag_override(tag_type, tag)
+            for tag in getattr(self, tag_type.name.lower()) or []:
+                new_tags.add_tag_override(tag_type, tag)
 
-    def __eq__(self, other):
-        if not isinstance(other, TagSetManager):
-            # don't attempt to compare against unrelated types
-            return NotImplemented
+        return new_tags
 
-        return self.tags == other.tags
+    def match(self, other: "SchedulingTags") -> bool:
+        self_required = set(self.require or [])
+        other_required = set(other.require or [])
+        self_rejected = set(self.reject or [])
+        other_rejected = set(other.reject or [])
 
-    def __repr__(self):
-        return f"{self.__class__} tags={[tag for tag in self.tags]}"
-
-    @staticmethod
-    def from_dict(tags: list[dict]) -> TagSetManager:
-        tag_list = []
-        for tag_val in tags.get('require') or []:
-            tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REQUIRE))
-        for tag_val in tags.get('prefer') or []:
-            tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.PREFER))
-        for tag_val in tags.get('accept') or []:
-            tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.ACCEPT))
-        for tag_val in tags.get('reject') or []:
-            tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REJECT))
-        return TagSetManager(tags=tag_list)
-
-    def to_dict(self) -> dict:
-        result_dict = {
-            'require': [tag.value for tag in self.tags if tag.tag_type == TagType.REQUIRE],
-            'prefer': [tag.value for tag in self.tags if tag.tag_type == TagType.PREFER],
-            'accept': [tag.value for tag in self.tags if tag.tag_type == TagType.ACCEPT],
-            'reject': [tag.value for tag in self.tags if tag.tag_type == TagType.REJECT]
-        }
-        return result_dict
+        return (
+            self_required.issubset(other.all_tag_values())
+            and other_required.issubset(self.all_tag_values())
+            and not self_rejected.intersection(other.all_tag_values())
+            and not other_rejected.intersection(self.all_tag_values())
+        )
 
-class Entity(object):
+    def score(self, other: "SchedulingTags") -> int:
+        return (
+            sum(
+                int(tag.tag_type) * int(o.tag_type)
+                for tag in self.filter()
+                for o in other.filter()
+                if tag.value == o.value
+            )
+            # penalize tags that don't exist in the other
+            - sum(
+                int(tag.tag_type)
+                for tag in self.tags
+                if tag.value not in other.all_tag_values()
+            )
+        )
 
-    merge_order = 0
 
-    def __init__(self, loader, id=None, abstract=False, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None,
-                 min_gpus=None, max_cores=None, max_mem=None, max_gpus=None, env=None, params=None, resubmit=None,
-                 tpv_tags=None, rank=None, inherits=None, context=None):
-        self.loader = loader
+class Entity(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
+    merge_order: ClassVar[int] = 0
+    id: Optional[str] = None
+    abstract: Optional[bool] = False
+    inherits: Optional[str] = None
+    cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None
+    min_cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    min_mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    min_gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None
+    max_cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    max_mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None
+    max_gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None
+    env: Annotated[
+        Optional[List[Dict[str, str]]], TPVFieldMetadata(complex_property=True)
+    ] = None
+    params: Annotated[
+        Optional[Dict[str, Any]], TPVFieldMetadata(complex_property=True)
+    ] = None
+    resubmit: Annotated[
+        Optional[Dict[str, str]], TPVFieldMetadata(complex_property=True)
+    ] = Field(default_factory=dict)
+    rank: Annotated[Optional[str], TPVFieldMetadata()] = None
+    context: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    evaluator: SkipJsonSchema[Optional[TPVCodeEvaluator]] = Field(
+        exclude=True, default=None
+    )
+    tpv_tags: Optional[SchedulingTags] = Field(
+        alias="scheduling", default_factory=SchedulingTags
+    )
+
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+        self.propagate_parent_properties(id=self.id, evaluator=self.evaluator)
+
+    def propagate_parent_properties(self, id=None, evaluator=None):
         self.id = id
-        self.abstract = galaxy_util.asbool(abstract)
-        self.cores = cores
-        self.mem = mem
-        self.gpus = gpus
-        self.min_cores = min_cores
-        self.min_mem = min_mem
-        self.min_gpus = min_gpus
-        self.max_cores = max_cores
-        self.max_mem = max_mem
-        self.max_gpus = max_gpus
-        self.env = self.convert_env(env)
-        self.params = params
-        self.resubmit = resubmit
-        self.tpv_tags = TagSetManager.from_dict(tpv_tags or {})
-        self.rank = rank
-        self.inherits = inherits
-        self.context = context
-        self.validate()
-
-    def __deepcopy__(self, memodict={}):
-        # make sure we don't deepcopy the loader: https://github.com/galaxyproject/total-perspective-vortex/issues/53
-        # xref: https://stackoverflow.com/a/15774013
-        cls = self.__class__
-        result = cls.__new__(cls)
-        memodict[id(self)] = result
-        for k, v in self.__dict__.items():
-            if k == "loader":
-                setattr(result, k, v)
-            else:
-                setattr(result, k, copy.deepcopy(v, memodict))
-        return result
-
-    def process_complex_property(self, prop, context, func, stringify=False):
-        if isinstance(prop, str):
-            return func(prop, context)
-        elif isinstance(prop, dict):
-            evaluated_props = {key: self.process_complex_property(childprop, context, func, stringify=stringify)
-                               for key, childprop in prop.items()}
-            return evaluated_props
-        elif isinstance(prop, list):
-            evaluated_props = [self.process_complex_property(childprop, context, func, stringify=stringify)
-                               for childprop in prop]
-            return evaluated_props
-        else:
-            return str(prop) if stringify else prop  # To handle special case of env vars provided as ints
-
-    def compile_complex_property(self, prop):
-        return self.process_complex_property(
-            prop, None, lambda p, c: self.loader.compile_code_block(p, as_f_string=True))
+        self.evaluator = evaluator
+        if evaluator:
+            self.precompile_properties(evaluator)
+
+    def precompile_properties(self, evaluator: TPVCodeEvaluator):
+        # compile properties and check for errors
+        if evaluator:
+            for name, value in self:
+                field = self.model_fields[name]
+                if field.metadata and field.metadata[0]:
+                    prop = field.metadata[0]
+                    if isinstance(prop, TPVFieldMetadata):
+                        if prop.complex_property:
+                            evaluator.compile_complex_property(value)
+                        else:
+                            evaluator.compile_code_block(value)
+
+    def __deepcopy__(self, memo: dict):
+        # make sure we don't deepcopy the evaluator: https://github.com/galaxyproject/total-perspective-vortex/issues/53
+        # xref: https://stackoverflow.com/a/68746763/10971151
+        memo[id(self.evaluator)] = self.evaluator
+        return super().__deepcopy__(memo)
 
-    def evaluate_complex_property(self, prop, context, stringify=False):
-        return self.process_complex_property(
-            prop, context, lambda p, c: self.loader.eval_code_block(p, c, as_f_string=True), stringify=stringify)
-
-    def convert_env(self, env):
+    @staticmethod
+    def convert_env(env):
         if isinstance(env, dict):
-            env = [dict(name=k, value=v) for (k, v) in env.items()]
+            env = [dict(name=k, value=str(v)) for (k, v) in env.items()]
         return env
 
-    def validate(self):
-        """
-        Validates each code block and makes sure the code can be compiled.
-        This process also results in the compiled code being cached by the loader,
-        so that future evaluations are faster.
-        """
-        if self.cores:
-            self.loader.compile_code_block(self.cores)
-        if self.mem:
-            self.loader.compile_code_block(self.mem)
-        if self.gpus:
-            self.loader.compile_code_block(self.gpus)
-        if self.min_cores:
-            self.loader.compile_code_block(self.min_cores)
-        if self.min_mem:
-            self.loader.compile_code_block(self.min_mem)
-        if self.min_gpus:
-            self.loader.compile_code_block(self.min_gpus)
-        if self.max_cores:
-            self.loader.compile_code_block(self.max_cores)
-        if self.max_mem:
-            self.loader.compile_code_block(self.max_mem)
-        if self.max_gpus:
-            self.loader.compile_code_block(self.max_gpus)
-        if self.env:
-            self.compile_complex_property(self.env)
-        if self.params:
-            self.compile_complex_property(self.params)
-        if self.resubmit:
-            self.compile_complex_property(self.resubmit)
-        if self.rank:
-            self.loader.compile_code_block(self.rank)
-
-    def __repr__(self):
-        return f"{self.__class__} id={self.id}, abstract={self.abstract}, cores={self.cores}, mem={self.mem}, " \
-               f"gpus={self.gpus}, min_cores = {self.min_cores}, min_mem = {self.min_mem}, " \
-               f"min_gpus = {self.min_gpus}, max_cores = {self.max_cores}, max_mem = {self.max_mem}, " \
-               f"max_gpus = {self.max_gpus}, env={self.env}, params={self.params}, resubmit={self.resubmit}, " \
-               f"tags={self.tpv_tags}, rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, "\
-               f"context={self.context}"
-
-    def __eq__(self, other):
-        if not isinstance(other, self.__class__):
-            # don't attempt to compare against unrelated types
-            return NotImplemented
-
-        return (
-            self.id == other.id and
-            self.abstract == other.abstract and
-            self.cores == other.cores and
-            self.mem == other.mem and
-            self.gpus == other.gpus and
-            self.min_cores == other.min_cores and
-            self.min_mem == other.min_mem and
-            self.min_gpus == other.min_gpus and
-            self.max_cores == other.max_cores and
-            self.max_mem == other.max_mem and
-            self.max_gpus == other.max_gpus and
-            self.env == other.env and
-            self.params == other.params and
-            self.resubmit == other.resubmit and
-            self.tpv_tags == other.tpv_tags and
-            self.inherits == other.inherits and
-            self.context == other.context
-        )
+    @model_validator(mode="before")
+    @classmethod
+    def preprocess(cls, values):
+        if values:
+            values["abstract"] = galaxy_util.asbool(values.get("abstract", False))
+            values["env"] = Entity.convert_env(values.get("env"))
+        return values
 
-    def merge_env_list(self, original, replace):
+    @staticmethod
+    def merge_env_list(original, replace):
         for i, original_elem in enumerate(original):
             for j, replace_elem in enumerate(replace):
-                if (("name" in replace_elem and original_elem.get("name") == replace_elem["name"])
-                        or original_elem == replace_elem):
+                if (
+                    "name" in replace_elem
+                    and original_elem.get("name") == replace_elem["name"]
+                ) or original_elem == replace_elem:
                     original[i] = replace.pop(j)
                     break
         original.extend(replace)
         return original
 
-    def override(self, entity):
+    @staticmethod
+    def override_single_property(
+        entity, entity1, entity2, property_name, field_copier=default_field_copier
+    ):
+        setattr(entity, property_name, field_copier(entity1, entity2, property_name))
+
+    def override(self, entity: "Entity") -> "Entity":
         if entity.merge_order <= self.merge_order:
             # Use the broader class as a base when copying. Useful in particular for Rules
-            new_entity = copy.copy(self)
+            new_entity = self.copy()
         else:
-            new_entity = copy.copy(entity)
-        new_entity.id = self.id or entity.id
-        new_entity.abstract = self.abstract and entity.abstract
-        new_entity.cores = self.cores if self.cores is not None else entity.cores
-        new_entity.mem = self.mem if self.mem is not None else entity.mem
-        new_entity.gpus = self.gpus if self.gpus is not None else entity.gpus
-        new_entity.min_cores = self.min_cores if self.min_cores is not None else entity.min_cores
-        new_entity.min_mem = self.min_mem if self.min_mem is not None else entity.min_mem
-        new_entity.min_gpus = self.min_gpus if self.min_gpus is not None else entity.min_gpus
-        new_entity.max_cores = self.max_cores if self.max_cores is not None else entity.max_cores
-        new_entity.max_mem = self.max_mem if self.max_mem is not None else entity.max_mem
-        new_entity.max_gpus = self.max_gpus if self.max_gpus is not None else entity.max_gpus
-        new_entity.env = self.merge_env_list(copy.deepcopy(entity.env) or [], copy.deepcopy(self.env) or [])
-        new_entity.params = copy.copy(entity.params) or {}
-        new_entity.params.update(self.params or {})
-        new_entity.resubmit = copy.copy(entity.resubmit) or {}
-        new_entity.resubmit.update(self.resubmit or {})
-        new_entity.rank = self.rank if self.rank is not None else entity.rank
-        new_entity.inherits = self.inherits if self.inherits is not None else entity.inherits
-        new_entity.context = copy.copy(entity.context) or {}
-        new_entity.context.update(self.context or {})
+            new_entity = entity.copy()
+        self.override_single_property(new_entity, self, entity, "id")
+        self.override_single_property(new_entity, self, entity, "abstract")
+        self.override_single_property(new_entity, self, entity, "cores")
+        self.override_single_property(new_entity, self, entity, "mem")
+        self.override_single_property(new_entity, self, entity, "gpus")
+        self.override_single_property(new_entity, self, entity, "min_cores")
+        self.override_single_property(new_entity, self, entity, "min_mem")
+        self.override_single_property(new_entity, self, entity, "min_gpus")
+        self.override_single_property(new_entity, self, entity, "max_cores")
+        self.override_single_property(new_entity, self, entity, "max_mem")
+        self.override_single_property(new_entity, self, entity, "max_gpus")
+        self.override_single_property(new_entity, self, entity, "max_gpus")
+        self.override_single_property(
+            new_entity,
+            self,
+            entity,
+            "env",
+            field_copier=lambda e1, e2, p: self.merge_env_list(
+                copy.deepcopy(entity.env) or [], copy.deepcopy(self.env) or []
+            ),
+        )
+        self.override_single_property(
+            new_entity, self, entity, "params", field_copier=default_dict_copier
+        )
+        self.override_single_property(
+            new_entity, self, entity, "resubmit", field_copier=default_dict_copier
+        )
+        self.override_single_property(new_entity, self, entity, "rank")
+        self.override_single_property(new_entity, self, entity, "inherits")
+        self.override_single_property(
+            new_entity, self, entity, "context", field_copier=default_dict_copier
+        )
         return new_entity
 
     def inherit(self, entity):
@@ -388,52 +387,70 @@ def combine(self, entity):
         :return:
         """
         new_entity = self.override(entity)
-        new_entity.id = f"{type(self).__name__}: {self.id}, {type(entity).__name__}: {entity.id}"
+        new_entity.id = (
+            f"{type(self).__name__}: {self.id}, {type(entity).__name__}: {entity.id}"
+        )
         new_entity.tpv_tags = entity.tpv_tags.combine(self.tpv_tags)
         return new_entity
 
-    def evaluate_resources(self, context):
+    def evaluate_resources(self, context: Dict[str, Any]):
         new_entity = copy.deepcopy(self)
         context.update(self.context or {})
         if self.min_gpus is not None:
-            new_entity.min_gpus = self.loader.eval_code_block(self.min_gpus, context)
-            context['min_gpus'] = new_entity.min_gpus
+            new_entity.min_gpus = self.evaluator.eval_code_block(self.min_gpus, context)
+            context["min_gpus"] = new_entity.min_gpus
         if self.min_cores is not None:
-            new_entity.min_cores = self.loader.eval_code_block(self.min_cores, context)
-            context['min_cores'] = new_entity.min_cores
+            new_entity.min_cores = self.evaluator.eval_code_block(
+                self.min_cores, context
+            )
+            context["min_cores"] = new_entity.min_cores
         if self.min_mem is not None:
-            new_entity.min_mem = self.loader.eval_code_block(self.min_mem, context)
-            context['min_mem'] = new_entity.min_mem
+            new_entity.min_mem = self.evaluator.eval_code_block(self.min_mem, context)
+            context["min_mem"] = new_entity.min_mem
         if self.max_gpus is not None:
-            new_entity.max_gpus = self.loader.eval_code_block(self.max_gpus, context)
-            context['max_gpus'] = new_entity.max_gpus
+            new_entity.max_gpus = self.evaluator.eval_code_block(self.max_gpus, context)
+            context["max_gpus"] = new_entity.max_gpus
         if self.max_cores is not None:
-            new_entity.max_cores = self.loader.eval_code_block(self.max_cores, context)
-            context['max_cores'] = new_entity.max_cores
+            new_entity.max_cores = self.evaluator.eval_code_block(
+                self.max_cores, context
+            )
+            context["max_cores"] = new_entity.max_cores
         if self.max_mem is not None:
-            new_entity.max_mem = self.loader.eval_code_block(self.max_mem, context)
-            context['max_mem'] = new_entity.max_mem
+            new_entity.max_mem = self.evaluator.eval_code_block(self.max_mem, context)
+            context["max_mem"] = new_entity.max_mem
         if self.gpus is not None:
-            new_entity.gpus = self.loader.eval_code_block(self.gpus, context)
+            new_entity.gpus = self.evaluator.eval_code_block(self.gpus, context)
             # clamp gpus
             new_entity.gpus = max(new_entity.min_gpus or 0, new_entity.gpus or 0)
-            new_entity.gpus = min(new_entity.max_gpus, new_entity.gpus) if new_entity.max_gpus else new_entity.gpus
-            context['gpus'] = new_entity.gpus
+            new_entity.gpus = (
+                min(new_entity.max_gpus, new_entity.gpus)
+                if new_entity.max_gpus
+                else new_entity.gpus
+            )
+            context["gpus"] = new_entity.gpus
         if self.cores is not None:
-            new_entity.cores = self.loader.eval_code_block(self.cores, context)
+            new_entity.cores = self.evaluator.eval_code_block(self.cores, context)
            # clamp cores
             new_entity.cores = max(new_entity.min_cores or 0, new_entity.cores or 0)
-            new_entity.cores = min(new_entity.max_cores, new_entity.cores) if new_entity.max_cores else new_entity.cores
-            context['cores'] = new_entity.cores
+            new_entity.cores = (
+                min(new_entity.max_cores, new_entity.cores)
+                if new_entity.max_cores
+                else new_entity.cores
+            )
+            context["cores"] = new_entity.cores
         if self.mem is not None:
-            new_entity.mem = self.loader.eval_code_block(self.mem, context)
+            new_entity.mem = self.evaluator.eval_code_block(self.mem, context)
             # clamp mem
             new_entity.mem = max(new_entity.min_mem or 0, new_entity.mem or 0)
-            new_entity.mem = min(new_entity.max_mem, new_entity.mem or 0) if new_entity.max_mem else new_entity.mem
-            context['mem'] = new_entity.mem
+            new_entity.mem = (
+                min(new_entity.max_mem, new_entity.mem or 0)
+                if new_entity.max_mem
+                else new_entity.mem
+            )
+            context["mem"] = new_entity.mem
         return new_entity
 
-    def evaluate(self, context):
+    def evaluate(self, context: Dict[str, Any]):
         """
         Evaluate expressions in entity properties that must be evaluated as late as possible, which is to
         say, after combining entity requirements. This includes env, params and resubmit, that rely on
@@ -443,104 +460,107 @@ def evaluate(self, context):
         """
         new_entity = self.evaluate_resources(context)
         if self.env:
-            new_entity.env = self.evaluate_complex_property(self.env, context, stringify=True)
-            context['env'] = new_entity.env
+            new_entity.env = self.evaluator.evaluate_complex_property(self.env, context)
+            context["env"] = new_entity.env
         if self.params:
-            new_entity.params = self.evaluate_complex_property(self.params, context)
-            context['params'] = new_entity.params
+            new_entity.params = self.evaluator.evaluate_complex_property(
+                self.params, context
+            )
+            context["params"] = new_entity.params
         if self.resubmit:
-            new_entity.resubmit = self.evaluate_complex_property(self.resubmit, context)
-            context['resubmit'] = new_entity.resubmit
+            new_entity.resubmit = self.evaluator.evaluate_complex_property(
+                self.resubmit, context
+            )
+            context["resubmit"] = new_entity.resubmit
         return new_entity
 
-    def rank_destinations(self, destinations, context):
+    def rank_destinations(
+        self, destinations: List["Destination"], context: Dict[str, Any]
+    ):
         if self.rank:
-            log.debug(f"Ranking destinations: {destinations} for entity: {self} using custom function")
-            context['candidate_destinations'] = destinations
-            return self.loader.eval_code_block(self.rank, context)
+            log.debug(
+                f"Ranking destinations: {destinations} for entity: {self} using custom"
+                " function"
+            )
+            context["candidate_destinations"] = destinations
+            return self.evaluator.eval_code_block(self.rank, context)
         else:
             # Sort destinations by priority
-            log.debug(f"Ranking destinations: {destinations} for entity: {self} using default ranker")
+            log.debug(
+                f"Ranking destinations: {destinations} for entity: {self} using default"
+                " ranker"
+            )
             return sorted(destinations, key=lambda d: d.score(self), reverse=True)
 
-    def to_dict(self):
-        dict_obj = {
-            'id': self.id,
-            'abstract': self.abstract,
-            'cores': self.cores,
-            'mem': self.mem,
-            'gpus': self.gpus,
-            'min_cores': self.min_cores,
-            'min_mem': self.min_mem,
-            'min_gpus': self.min_gpus,
-            'max_cores': self.max_cores,
-            'max_mem': self.max_mem,
-            'max_gpus': self.max_gpus,
-            'env': self.env,
-            'params': self.params,
-            'resubmit': self.resubmit,
-            'scheduling': self.tpv_tags.to_dict(),
-            'inherits': self.inherits,
-            'context': self.context
-        }
-        return dict_obj
+    def model_dump(self, **kwargs):
+        # Ensure by_alias is set to True to use the field aliases during serialization
+        kwargs.setdefault("by_alias", True)
+        return super().model_dump(**kwargs)
 
+    def dict(self, **kwargs):
+        # by_alias is set to True to use the field aliases during serialization
+        kwargs.setdefault("by_alias", True)
+        return super().dict(**kwargs)
 
-class EntityWithRules(Entity):
-
-    merge_order = 1
-
-    def __init__(self, loader, id=None, abstract=False, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None,
-                 min_gpus=None, max_cores=None, max_mem=None, max_gpus=None, env=None,
-                 params=None, resubmit=None, tpv_tags=None, rank=None, inherits=None, context=None, rules=None):
-        super().__init__(loader, id=id, abstract=abstract, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores,
-                         min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus,
-                         env=env, params=params, resubmit=resubmit, tpv_tags=tpv_tags, rank=rank, inherits=inherits,
-                         context=context)
-        self.rules = self.validate_rules(rules)
-
-    def validate_rules(self, rules: list) -> list:
-        validated = {}
-        for rule in rules or []:
-            try:
-                validated_rule = Rule.from_dict(self.loader, rule)
-                validated[validated_rule.id] = validated_rule
-            except Exception:
-                log.exception(f"Could not load rule for entity: {self.__class__} with id: {self.id} and data: {rule}")
-                raise
-        return validated
+class Rule(Entity):
+    rule_counter: ClassVar[int] = 0
+    id: Optional[str] = Field(default_factory=lambda: Rule.set_default_id())
+    if_condition: str | bool = Field(alias="if")
+    execute: Optional[str] = None
+    fail: Optional[str] = None
 
     @classmethod
-    def from_dict(cls: type, loader, entity_dict):
-        return cls(
-            loader=loader,
-            id=entity_dict.get('id'),
-            abstract=entity_dict.get('abstract'),
-            cores=entity_dict.get('cores'),
-            mem=entity_dict.get('mem'),
-            gpus=entity_dict.get('gpus'),
-            min_cores=entity_dict.get('min_cores'),
-            min_mem=entity_dict.get('min_mem'),
-            min_gpus=entity_dict.get('min_gpus'),
-            max_cores=entity_dict.get('max_cores'),
-            max_mem=entity_dict.get('max_mem'),
-            max_gpus=entity_dict.get('max_gpus'),
-            env=entity_dict.get('env'),
-            params=entity_dict.get('params'),
-            resubmit=entity_dict.get('resubmit'),
-            tpv_tags=entity_dict.get('scheduling'),
-            rank=entity_dict.get('rank'),
-            inherits=entity_dict.get('inherits'),
-            context=entity_dict.get('context'),
-            rules=entity_dict.get('rules')
-        )
-
-    def to_dict(self):
-        dict_obj = super().to_dict()
-        dict_obj['rules'] = [rule.to_dict() for rule in self.rules.values()]
-        return dict_obj
+    def set_default_id(cls):
+        cls.rule_counter += 1
+        return f"tpv_rule_{cls.rule_counter}"
 
     def override(self, entity):
+        new_entity = super().override(entity)
+        if isinstance(entity, Rule):
+            self.override_single_property(new_entity, self, entity, "if_condition")
+            self.override_single_property(new_entity, self, entity, "execute")
+            self.override_single_property(new_entity, self, entity, "fail")
+        return new_entity
+
+    def is_matching(self, context):
+        if self.evaluator.eval_code_block(self.if_condition, context):
+            return True
+        else:
+            return False
+
+    def evaluate(self, context):
+        if self.fail:
+            from galaxy.jobs.mapper import JobMappingException
+
+            raise JobMappingException(
+                self.evaluator.eval_code_block(self.fail, context, as_f_string=True)
+            )
+        if self.execute:
+            self.evaluator.eval_code_block(self.execute, context, exec_only=True)
+            # return any changes made to the entity
+            return context["entity"]
+        return self
+
+
+class EntityWithRules(Entity):
+    merge_order: ClassVar[int] = 1
+    rules: Optional[Dict[str, Rule]] = Field(default_factory=dict)
+
+    def propagate_parent_properties(self, id=None, evaluator=None):
+        super().propagate_parent_properties(id=id, evaluator=evaluator)
+        for rule in self.rules.values():
+            rule.evaluator = evaluator
+
+    @model_validator(mode="before")
+    @classmethod
+    def deserialize_rules(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "rules" in values and isinstance(values["rules"], list):
+            rules = (Rule(**r) for r in values["rules"])
+            values["rules"] = {rule.id: rule for rule in rules}
+        return values
+
+    def override(self, entity: Entity):
         new_entity = super().override(entity)
         new_entity.rules = copy.deepcopy(entity.rules)
         new_entity.rules.update(self.rules or {})
@@ -549,7 +569,7 @@ def override(self, entity):
             new_entity.rules[rule.id] = rule.inherit(entity.rules[rule.id])
         return new_entity
 
-    def evaluate_rules(self, context):
+    def evaluate_rules(self, context: Dict[str, str]):
         new_entity = copy.deepcopy(self)
         context.update(new_entity.context or {})
         for rule in self.rules.values():
@@ -560,184 +580,87 @@ def evaluate_rules(self, context):
                 new_entity.cores = rule.cores or new_entity.cores
                 new_entity.mem = rule.mem or new_entity.mem
                 new_entity.id = f"{new_entity.id}, Rule: {rule.id}"
-                context.update({
-                    'entity': new_entity
-                })
+                context.update({"entity": new_entity})
         return new_entity
 
-    def evaluate(self, context):
+    def evaluate(self, context: Dict[str, str]):
         new_entity = self.evaluate_rules(context)
         return super(EntityWithRules, new_entity).evaluate(context)
 
-    def __repr__(self):
-        return super().__repr__() + f", rules={self.rules}"
-
-    def __eq__(self, other):
-        return super().__eq__(other) and (
-            self.rules == other.rules
-        )
-
 
 class Tool(EntityWithRules):
-
-    merge_order = 2
+    merge_order: ClassVar[int] = 2
+    pass
 
 
 class Role(EntityWithRules):
-
-    merge_order = 3
+    merge_order: ClassVar[int] = 3
 
 
 class User(EntityWithRules):
-
-    merge_order = 4
+    merge_order: ClassVar[int] = 4
 
 
 class Destination(EntityWithRules):
-
-    merge_order = 5
-
-    def __init__(self, loader, id=None, abstract=False, runner=None, dest_name=None, cores=None, mem=None, gpus=None,
-                 min_cores=None, min_mem=None, min_gpus=None, max_cores=None, max_mem=None, max_gpus=None,
-                 min_accepted_cores=None, min_accepted_mem=None, min_accepted_gpus=None,
-                 max_accepted_cores=None, max_accepted_mem=None, max_accepted_gpus=None, env=None, params=None,
-                 resubmit=None, tpv_dest_tags=None, inherits=None, context=None, rules=None, handler_tags=None):
-        self.runner = runner
-        self.dest_name = dest_name or id
-        self.min_accepted_cores = min_accepted_cores
-        self.min_accepted_mem = min_accepted_mem
-        self.min_accepted_gpus = min_accepted_gpus
-        self.max_accepted_cores = max_accepted_cores
-        self.max_accepted_mem = max_accepted_mem
-        self.max_accepted_gpus = max_accepted_gpus
-        self.tpv_dest_tags = TagSetManager.from_dict(tpv_dest_tags or {})
-        # Handler tags refer to Galaxy's job handler level tags
-        self.handler_tags = handler_tags
-        super().__init__(loader, id=id, abstract=abstract, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores,
-                         min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus,
-                         env=env, params=params, resubmit=resubmit, tpv_tags=None, inherits=inherits, context=context,
-                         rules=rules)
-
-    @staticmethod
-    def from_dict(loader, entity_dict):
-        return Destination(
-            loader=loader,
-            id=entity_dict.get('id'),
-            abstract=entity_dict.get('abstract'),
-            runner=entity_dict.get('runner'),
-            dest_name=entity_dict.get('destination_name_override'),
-            cores=entity_dict.get('cores'),
-            mem=entity_dict.get('mem'),
-            gpus=entity_dict.get('gpus'),
-            min_cores=entity_dict.get('min_cores'),
-            min_mem=entity_dict.get('min_mem'),
-            min_gpus=entity_dict.get('min_gpus'),
-            max_cores=entity_dict.get('max_cores'),
-            max_mem=entity_dict.get('max_mem'),
-            max_gpus=entity_dict.get('max_gpus'),
-            min_accepted_cores=entity_dict.get('min_accepted_cores'),
-            min_accepted_mem=entity_dict.get('min_accepted_mem'),
-            min_accepted_gpus=entity_dict.get('min_accepted_gpus'),
-            max_accepted_cores=entity_dict.get('max_accepted_cores'),
-            max_accepted_mem=entity_dict.get('max_accepted_mem'),
-            max_accepted_gpus=entity_dict.get('max_accepted_gpus'),
-            env=entity_dict.get('env'),
-            params=entity_dict.get('params'),
-            resubmit=entity_dict.get('resubmit'),
-            tpv_dest_tags=entity_dict.get('scheduling'),
-            inherits=entity_dict.get('inherits'),
-            context=entity_dict.get('context'),
-            rules=entity_dict.get('rules'),
-            handler_tags=entity_dict.get('tags')
-        )
-
-    def to_dict(self):
-        dict_obj = super().to_dict()
-        dict_obj['runner'] = self.runner
-        dict_obj['destination_name_override'] = self.dest_name
-        dict_obj['min_accepted_cores'] = self.min_accepted_cores
-        dict_obj['min_accepted_mem'] = self.min_accepted_mem
-        dict_obj['min_accepted_gpus'] = self.min_accepted_gpus
-        dict_obj['max_accepted_cores'] = self.max_accepted_cores
-        dict_obj['max_accepted_mem'] = self.max_accepted_mem
-        dict_obj['max_accepted_gpus'] = self.max_accepted_gpus
-        dict_obj['scheduling'] = self.tpv_dest_tags.to_dict()
-        dict_obj['tags'] = self.handler_tags
-        return dict_obj
-
-    def __eq__(self, other):
-        if not isinstance(other, Destination):
-            # don't attempt to compare against unrelated types
-            return NotImplemented
-
-        return super().__eq__(other) and (
-            self.runner == other.runner and
-            self.dest_name == other.dest_name and
-            self.min_accepted_cores == other.min_accepted_cores and
-            self.min_accepted_mem == other.min_accepted_mem and
-            self.min_accepted_gpus == other.min_accepted_gpus and
-            self.max_accepted_cores == other.max_accepted_cores and
-            self.max_accepted_mem == other.max_accepted_mem and
-            self.max_accepted_gpus == other.max_accepted_gpus and
-            self.tpv_dest_tags == other.tpv_dest_tags and
-            self.handler_tags == other.handler_tags
-        )
-
-    def __repr__(self):
-        return f"runner={self.runner}, dest_name={self.dest_name}, min_accepted_cores={self.min_accepted_cores}, "\
-               f"min_accepted_mem={self.min_accepted_mem}, min_accepted_gpus={self.min_accepted_gpus}, "\
-               f"max_accepted_cores={self.max_accepted_cores}, max_accepted_mem={self.max_accepted_mem}, "\
-               f"max_accepted_gpus={self.max_accepted_gpus}, tpv_dest_tags={self.tpv_dest_tags}, "\
-               f"handler_tags={self.handler_tags}" + super().__repr__()
-
-    def override(self, entity):
+    merge_order: ClassVar[int] = 5
+    runner: Optional[str] = None
+    max_accepted_cores: Optional[int | float] = None
+    max_accepted_mem: Optional[int | float] = None
+    max_accepted_gpus: Optional[int] = None
+    min_accepted_cores: Optional[int | float] = None
+    min_accepted_mem: Optional[int | float] = None
+    min_accepted_gpus: Optional[int] = None
+    dest_name: Optional[str] = Field(alias="destination_name_override", default=None)
+    # tpv_tags track what tags the entity being scheduled requested, while tpv_dest_tags track what the destination
+    # supports. When serializing a Destination, we don't need tpv_tags, only tpv_dest_tags.
+    tpv_tags: SkipJsonSchema[Optional[SchedulingTags]] = Field(
+        exclude=True, default_factory=SchedulingTags
+    )
+    tpv_dest_tags: Optional[SchedulingTags] = Field(
+        alias="scheduling", default_factory=SchedulingTags
+    )
+    handler_tags: Annotated[
+        Optional[List[str]], TPVFieldMetadata(complex_property=True)
+    ] = Field(alias="tags", default_factory=list)
+
+    def propagate_parent_properties(self, id=None, evaluator=None):
+        super().propagate_parent_properties(id=id, evaluator=evaluator)
+        self.dest_name = self.dest_name or self.id
+
+    def override(self, entity: Entity):
         new_entity = super().override(entity)
-        new_entity.runner = self.runner if self.runner is not None else getattr(entity, 'runner', None)
-        new_entity.dest_name = self.dest_name if self.dest_name is not None else getattr(entity, 'dest_name', None)
-        new_entity.min_accepted_cores = (self.min_accepted_cores if self.min_accepted_cores is not None
-                                         else getattr(entity, 'min_accepted_cores', None))
-        new_entity.min_accepted_mem = (self.min_accepted_mem if self.min_accepted_mem is not None
-                                       else getattr(entity, 'min_accepted_mem', None))
-        new_entity.min_accepted_gpus = (self.min_accepted_gpus if self.min_accepted_gpus is not None
-                                        else getattr(entity, 'min_accepted_gpus', None))
-        new_entity.max_accepted_cores = (self.max_accepted_cores if self.max_accepted_cores is not None
-                                         else getattr(entity, 'max_accepted_cores', None))
-        new_entity.max_accepted_mem = (self.max_accepted_mem if self.max_accepted_mem is not None
-                                       else getattr(entity, 'max_accepted_mem', None))
-        new_entity.max_accepted_gpus = (self.max_accepted_gpus if self.max_accepted_gpus is not None
-                                        else getattr(entity, 'max_accepted_gpus', None))
-        new_entity.handler_tags = self.handler_tags or getattr(entity, 'handler_tags', None)
+        self.override_single_property(new_entity, self, entity, "runner")
+        self.override_single_property(new_entity, self, entity, "dest_name")
+        self.override_single_property(new_entity, self, entity, "min_accepted_cores")
+        self.override_single_property(new_entity, self, entity, "min_accepted_mem")
+        self.override_single_property(new_entity, self, entity, "min_accepted_gpus")
+        self.override_single_property(new_entity, self, entity, "max_accepted_cores")
+        self.override_single_property(new_entity, self, entity, "max_accepted_mem")
+        self.override_single_property(new_entity, self, entity, "max_accepted_gpus")
+        self.override_single_property(new_entity, self, entity, "handler_tags")
        return new_entity
 
-    def validate(self):
-        """
-        Validates each code block and makes sure the code can be compiled.
-        This process also results in the compiled code being cached by the loader,
-        so that future evaluations are faster.
- """ - super().validate() - if self.dest_name: - self.loader.compile_code_block(self.dest_name, as_f_string=True) - if self.handler_tags: - self.compile_complex_property(self.handler_tags) - - def evaluate(self, context): + def evaluate(self, context: Dict[str, Any]): new_entity = super(Destination, self).evaluate(context) if self.dest_name is not None: - new_entity.dest_name = self.loader.eval_code_block(self.dest_name, context, as_f_string=True) - context['dest_name'] = new_entity.dest_name + new_entity.dest_name = self.evaluator.eval_code_block( + self.dest_name, context, as_f_string=True + ) + context["dest_name"] = new_entity.dest_name if self.handler_tags is not None: - new_entity.handler_tags = self.evaluate_complex_property(self.handler_tags, context) - context['handler_tags'] = new_entity.handler_tags + new_entity.handler_tags = self.evaluator.evaluate_complex_property( + self.handler_tags, context + ) + context["handler_tags"] = new_entity.handler_tags return new_entity - def inherit(self, entity): + def inherit(self, entity: Entity): new_entity = super().inherit(entity) if entity: new_entity.tpv_dest_tags = self.tpv_dest_tags.inherit(entity.tpv_dest_tags) return new_entity - def matches(self, entity, context): + def matches(self, entity: Entity, context: Dict[str, Any]): """ The match operation checks whether @@ -754,17 +677,41 @@ def matches(self, entity, context): """ if self.abstract: return False - if self.max_accepted_cores is not None and entity.cores is not None and self.max_accepted_cores < entity.cores: + if ( + self.max_accepted_cores is not None + and entity.cores is not None + and self.max_accepted_cores < entity.cores + ): return False - if self.max_accepted_mem is not None and entity.mem is not None and self.max_accepted_mem < entity.mem: + if ( + self.max_accepted_mem is not None + and entity.mem is not None + and self.max_accepted_mem < entity.mem + ): return False - if self.max_accepted_gpus is not None and entity.gpus is not None and self.max_accepted_gpus < entity.gpus: + if ( + self.max_accepted_gpus is not None + and entity.gpus is not None + and self.max_accepted_gpus < entity.gpus + ): return False - if self.min_accepted_cores is not None and entity.cores is not None and self.min_accepted_cores > entity.cores: + if ( + self.min_accepted_cores is not None + and entity.cores is not None + and self.min_accepted_cores > entity.cores + ): return False - if self.min_accepted_mem is not None and entity.mem is not None and self.min_accepted_mem > entity.mem: + if ( + self.min_accepted_mem is not None + and entity.mem is not None + and self.min_accepted_mem > entity.mem + ): return False - if self.min_accepted_gpus is not None and entity.gpus is not None and self.min_accepted_gpus > entity.gpus: + if ( + self.min_accepted_gpus is not None + and entity.gpus is not None + and self.min_accepted_gpus > entity.gpus + ): return False return entity.tpv_tags.match(self.tpv_dest_tags or {}) @@ -780,89 +727,33 @@ def score(self, entity): return score -class Rule(Entity): - - rule_counter = 0 - merge_order = 0 - - def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None, min_gpus=None, - max_cores=None, max_mem=None, max_gpus=None, env=None, params=None, resubmit=None, - tpv_tags=None, inherits=None, context=None, match=None, execute=None, fail=None): - if not id: - Rule.rule_counter += 1 - id = f"tpv_rule_{Rule.rule_counter}" - super().__init__(loader, id=id, abstract=False, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores, - 
-                         min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus,
-                         env=env, params=params, resubmit=resubmit, tpv_tags=tpv_tags, context=context,
-                         inherits=inherits)
-        self.match = match
-        self.execute = execute
-        self.fail = fail
-        if self.match:
-            self.loader.compile_code_block(self.match)
-        if self.execute:
-            self.loader.compile_code_block(self.execute, exec_only=True)
-        if self.fail:
-            self.loader.compile_code_block(self.fail, as_f_string=True)
-
-    @staticmethod
-    def from_dict(loader, entity_dict):
-        return Rule(
-            loader=loader,
-            id=entity_dict.get('id'),
-            cores=entity_dict.get('cores'),
-            mem=entity_dict.get('mem'),
-            gpus=entity_dict.get('gpus'),
-            min_cores=entity_dict.get('min_cores'),
-            min_mem=entity_dict.get('min_mem'),
-            min_gpus=entity_dict.get('min_gpus'),
-            max_cores=entity_dict.get('max_cores'),
-            max_mem=entity_dict.get('max_mem'),
-            max_gpus=entity_dict.get('max_gpus'),
-            env=entity_dict.get('env'),
-            params=entity_dict.get('params'),
-            resubmit=entity_dict.get('resubmit'),
-            tpv_tags=entity_dict.get('scheduling'),
-            inherits=entity_dict.get('inherits'),
-            context=entity_dict.get('context'),
-            # TODO: Remove deprecated match clause in future
-            match=entity_dict.get('if') or entity_dict.get('match'),
-            execute=entity_dict.get('execute'),
-            fail=entity_dict.get('fail')
-        )
-
-    def to_dict(self):
-        dict_obj = super().to_dict()
-        dict_obj['if'] = self.match
-        dict_obj['execute'] = self.execute
-        dict_obj['fail'] = self.fail
-        return dict_obj
-
-    def override(self, entity):
-        new_entity = super().override(entity)
-        new_entity.match = self.match if self.match is not None else getattr(entity, 'match', None)
-        new_entity.execute = self.execute if self.execute is not None else getattr(entity, 'execute', None)
-        new_entity.fail = self.fail if self.fail is not None else getattr(entity, 'fail', None)
-        return new_entity
-
-    def __repr__(self):
-        return super().__repr__() + f", if={self.match[:10] if self.match else ''}, " \
-               f"execute={self.execute[:10] if self.execute else ''}, " \
-               f"fail={self.fail[:10] if self.fail else ''}"
-
-    def is_matching(self, context):
-        if self.loader.eval_code_block(self.match, context):
-            return True
-        else:
-            return False
-
-    def evaluate(self, context):
-        if self.fail:
-            from galaxy.jobs.mapper import JobMappingException
-            raise JobMappingException(
-                self.loader.eval_code_block(self.fail, context, as_f_string=True))
-        if self.execute:
-            self.loader.eval_code_block(self.execute, context, exec_only=True)
-            # return any changes made to the entity
-            return context['entity']
+class GlobalConfig(BaseModel):
+    default_inherits: Optional[str] = None
+    context: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+
+class TPVConfig(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    global_config: Optional[GlobalConfig] = Field(
+        alias="global", default_factory=GlobalConfig
+    )
+    evaluator: SkipJsonSchema[Optional[TPVCodeEvaluator]] = Field(
+        exclude=True, default=None
+    )
+    tools: Optional[Dict[str, Tool]] = Field(default_factory=dict)
+    users: Optional[Dict[str, User]] = Field(default_factory=dict)
+    roles: Optional[Dict[str, Role]] = Field(default_factory=dict)
+    destinations: Optional[Dict[str, Destination]] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def propagate_parent_properties(self):
+        if self.evaluator:
+            for id, tool in self.tools.items():
+                tool.propagate_parent_properties(id=id, evaluator=self.evaluator)
+            for id, user in self.users.items():
+                user.propagate_parent_properties(id=id, evaluator=self.evaluator)
+            for id, role in self.roles.items():
+                role.propagate_parent_properties(id=id, evaluator=self.evaluator)
+            for id, destination in self.destinations.items():
+                destination.propagate_parent_properties(id=id, evaluator=self.evaluator)
+        return self
diff --git a/tpv/core/evaluator.py b/tpv/core/evaluator.py
new file mode 100644
index 0000000..4731ac9
--- /dev/null
+++ b/tpv/core/evaluator.py
@@ -0,0 +1,47 @@
+import abc
+from typing import Any, Dict, Optional
+
+
+class TPVCodeEvaluator(abc.ABC):
+
+    @abc.abstractmethod
+    def compile_code_block(self, code: str, as_f_string=False, exec_only=False):
+        pass
+
+    @abc.abstractmethod
+    def eval_code_block(
+        self, code: str, context: Dict[str, Any], as_f_string=False, exec_only=False
+    ):
+        pass
+
+    def process_complex_property(
+        self, prop: Any, context: Optional[Dict[str, Any]], func
+    ):
+        if isinstance(prop, str):
+            return func(prop, context)
+        elif isinstance(prop, dict):
+            evaluated_props = {
+                key: self.process_complex_property(childprop, context, func)
+                for key, childprop in prop.items()
+            }
+            return evaluated_props
+        elif isinstance(prop, list):
+            evaluated_props = [
+                self.process_complex_property(childprop, context, func)
+                for childprop in prop
+            ]
+            return evaluated_props
+        else:
+            return prop
+
+    def compile_complex_property(self, prop):
+        # no evaluation context is needed at compile time, hence None
+        return self.process_complex_property(
+            prop, None, lambda p, c: self.compile_code_block(p, as_f_string=True)
+        )
+
+    def evaluate_complex_property(self, prop, context: Dict[str, Any]):
+        return self.process_complex_property(
+            prop,
+            context,
+            lambda p, c: self.eval_code_block(p, c, as_f_string=True),
+        )
diff --git a/tpv/core/loader.py b/tpv/core/loader.py
index 99f469c..3e189e9 100644
--- a/tpv/core/loader.py
+++ b/tpv/core/loader.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
+
 import ast
 import functools
 import logging
+from typing import Any, Dict
-from . import helpers
-from . import util
-from .entities import Tool, User, Role, Destination, Entity
+from . import helpers, util
+from .entities import Entity, GlobalConfig, TPVConfig
+from .evaluator import TPVCodeEvaluator

 log = logging.getLogger(__name__)
@@ -14,40 +16,52 @@ class InvalidParentException(Exception):
     pass


-class TPVConfigLoader(object):
+class TPVConfigLoader(TPVCodeEvaluator):
+
+    def __init__(self, tpv_config: Dict[str, Any]):
+        self.compile_code_block = functools.lru_cache(maxsize=None)(
+            self.__compile_code_block
+        )
+        self.config = TPVConfig(evaluator=self, **tpv_config)
+        self.process_entities(self.config)

-    def __init__(self, tpv_config: dict):
-        self.compile_code_block = functools.lru_cache(maxsize=None)(self.__compile_code_block)
-        self.global_settings = tpv_config.get('global', {})
-        entities = self.load_entities(tpv_config)
-        self.tools = entities.get('tools')
-        self.users = entities.get('users')
-        self.roles = entities.get('roles')
-        self.destinations = entities.get('destinations')
+    def compile_code_block(self, code, as_f_string=False, exec_only=False):
+        # interface method, replaced with an instance-based lru cache in the constructor
+        pass

     def __compile_code_block(self, code, as_f_string=False, exec_only=False):
         if as_f_string:
             code_str = "f'''" + str(code) + "'''"
         else:
             code_str = str(code)
-        block = ast.parse(code_str, mode='exec')
+        block = ast.parse(code_str, mode="exec")
         if exec_only:
-            return compile(block, '<string>', mode='exec'), None
+            return compile(block, "<string>", mode="exec"), None
         else:
             # assumes last node is an expression
             last = ast.Expression(block.body.pop().value)
-            return compile(block, '<string>', mode='exec'), compile(last, '<string>', mode='eval')
+            return compile(block, "<string>", mode="exec"), compile(
+                last, "<string>", mode="eval"
+            )  # https://stackoverflow.com/a/39381428

     def eval_code_block(self, code, context, as_f_string=False, exec_only=False):
-        exec_block, eval_block = self.compile_code_block(code, as_f_string=as_f_string, exec_only=exec_only)
+        exec_block, eval_block = self.compile_code_block(
+            code, as_f_string=as_f_string, exec_only=exec_only
+        )
         locals = dict(globals())
         locals.update(context)
-        locals.update({
-            'helpers': helpers,
-            # Don't unnecessarily compute input_size unless it's referred to
-            'input_size': helpers.input_size(context['job']) if 'input_size' in str(code) else 0
-        })
+        locals.update(
+            {
+                "helpers": helpers,
+                # Don't unnecessarily compute input_size unless it's referred to
+                "input_size": (
+                    helpers.input_size(context["job"])
+                    if "input_size" in str(code)
+                    else 0
+                ),
+            }
+        )
         exec(exec_block, locals)
         if eval_block:
             return eval(eval_block, locals)
@@ -58,8 +72,10 @@ def process_inheritance(self, entity_list: dict[str, Entity], entity: Entity):
         if entity.inherits:
             parent_entity = entity_list.get(entity.inherits)
             if not parent_entity:
-                raise InvalidParentException(f"The specified parent: {entity.inherits} for"
-                                             f" entity: {entity} does not exist")
+                raise InvalidParentException(
+                    f"The specified parent: {entity.inherits} for"
+                    f" entity: {entity} does not exist"
+                )
             return entity.inherit(self.process_inheritance(entity_list, parent_entity))
     # do not process default inheritance here, only at runtime, as multiple can cause default inheritance
     # to override later matches.
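(Reviewer note: process_inheritance above resolves an entity's full `inherits` chain recursively before any merging with more specific matches happens. A minimal standalone sketch of the same idea follows; `SimpleEntity` and `resolve_chain` are illustrative stand-ins invented for this note, not TPV's actual API:

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class SimpleEntity:
        id: str
        inherits: Optional[str] = None
        cores: Optional[int] = None

        def inherit(self, parent: "SimpleEntity") -> "SimpleEntity":
            # child values win; unset values fall back to the resolved parent
            return SimpleEntity(
                id=self.id,
                inherits=self.inherits,
                cores=self.cores if self.cores is not None else parent.cores,
            )

    def resolve_chain(entities: Dict[str, SimpleEntity], entity: SimpleEntity) -> SimpleEntity:
        if entity.inherits:
            parent = entities.get(entity.inherits)
            if not parent:
                raise ValueError(f"parent {entity.inherits} of {entity.id} does not exist")
            return entity.inherit(resolve_chain(entities, parent))
        return entity

    entities = {
        "base": SimpleEntity(id="base", cores=2),
        "bwa": SimpleEntity(id="bwa", inherits="base"),
    }
    assert resolve_chain(entities, entities["bwa"]).cores == 2  # inherited from base

Default inheritance is deliberately absent from this sketch, mirroring the comment above: it is applied only at runtime so it cannot shadow later, more specific matches.)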
@@ -69,38 +85,26 @@ def recompute_inheritance(self, entities: dict[str, Entity]):
         for key, entity in entities.items():
             entities[key] = self.process_inheritance(entities, entity)

-    def validate_entities(self, entity_class: type, entity_list: dict) -> dict:
-        # This code relies on dict ordering guarantees provided since python 3.6
-        validated = {}
-        for entity_id, entity_dict in entity_list.items():
-            try:
-                if not entity_dict:
-                    entity_dict = {}
-                entity_dict['id'] = entity_id
-                validated[entity_id] = entity_class.from_dict(self, entity_dict)
-            except Exception:
-                log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}")
-                raise
-        self.recompute_inheritance(validated)
-        return validated
-
-    def load_entities(self, tpv_config: dict) -> dict:
-        validated = {
-            'tools': self.validate_entities(Tool, tpv_config.get('tools', {})),
-            'users': self.validate_entities(User, tpv_config.get('users', {})),
-            'roles': self.validate_entities(Role, tpv_config.get('roles', {})),
-            'destinations': self.validate_entities(Destination, tpv_config.get('destinations', {}))
-        }
-        return validated
-
-    def inherit_globals(self, globals_other):
-        if globals_other:
-            self.global_settings.update({'default_inherits': globals_other.get('default_inherits')}
-                                        if globals_other.get('default_inherits') else {})
-            self.global_settings['context'] = self.global_settings.get('context') or {}
-            self.global_settings['context'].update(globals_other.get('context') or {})
+    def validate_entities(self, entities: Dict[str, Entity]) -> None:
+        self.recompute_inheritance(entities)

-    def inherit_existing_entities(self, entities_current, entities_new):
+    def process_entities(self, tpv_config: TPVConfig) -> None:
+        self.validate_entities(tpv_config.tools)
+        self.validate_entities(tpv_config.users)
+        self.validate_entities(tpv_config.roles)
+        self.validate_entities(tpv_config.destinations)
+
+    def inherit_globals(self, globals_other: GlobalConfig):
+        if globals_other:
+            self.config.global_config.default_inherits = (
+                globals_other.default_inherits
+                or self.config.global_config.default_inherits
+            )
+            self.config.global_config.context.update(globals_other.context)
+
+    def inherit_existing_entities(
+        self, entities_current: dict[str, Entity], entities_new: dict[str, Entity]
+    ):
         for entity in entities_new.values():
             if entities_current.get(entity.id):
                 current_entity = entities_current.get(entity.id)
@@ -111,14 +115,21 @@ def inherit_existing_entities(self, entities_current, entities_new):
             entities_current[entity.id] = entity
         self.recompute_inheritance(entities_current)

+    def merge_config(self, config: TPVConfig):
+        self.inherit_globals(config.global_config)
+        self.inherit_existing_entities(self.config.tools, config.tools)
+        self.inherit_existing_entities(self.config.users, config.users)
+        self.inherit_existing_entities(self.config.roles, config.roles)
+        self.inherit_existing_entities(self.config.destinations, config.destinations)
+
     def merge_loader(self, loader: TPVConfigLoader):
-        self.inherit_globals(loader.global_settings)
-        self.inherit_existing_entities(self.tools, loader.tools)
-        self.inherit_existing_entities(self.users, loader.users)
-        self.inherit_existing_entities(self.roles, loader.roles)
-        self.inherit_existing_entities(self.destinations, loader.destinations)
+        self.merge_config(loader.config)

     @staticmethod
     def from_url_or_path(url_or_path: str):
         tpv_config = util.load_yaml_from_url_or_path(url_or_path)
-        return TPVConfigLoader(tpv_config)
+        try:
+            return TPVConfigLoader(tpv_config)
+        except Exception:
+            log.exception(f"Error loading TPV config: {url_or_path}")
+            raise
diff --git a/tpv/core/mapper.py b/tpv/core/mapper.py
index 6d5db57..66cb4a7 100644
--- a/tpv/core/mapper.py
+++ b/tpv/core/mapper.py
@@ -2,7 +2,7 @@
 import logging
 import re

-from .entities import Tool, TryNextDestinationOrFail, TryNextDestinationOrWait
+from .entities import Entity, Tool, TryNextDestinationOrFail, TryNextDestinationOrWait
 from .loader import TPVConfigLoader

 from galaxy.jobs import JobDestination
@@ -15,14 +15,10 @@ class EntityToDestinationMapper(object):

     def __init__(self, loader: TPVConfigLoader):
         self.loader = loader
-        self.entities = {
-            "tools": loader.tools,
-            "users": loader.users,
-            "roles": loader.roles
-        }
-        self.destinations = loader.destinations
-        self.default_inherits = loader.global_settings.get('default_inherits')
-        self.global_context = loader.global_settings.get('context')
+        self.config = loader.config
+        self.destinations = self.config.destinations
+        self.default_inherits = self.config.global_config.default_inherits
+        self.global_context = self.config.global_config.context
         self.lookup_tool_regex = functools.lru_cache(maxsize=None)(self.__compile_tool_regex)
         self.inherit_matching_entities = functools.lru_cache(maxsize=None)(self.__inherit_matching_entities)
@@ -33,7 +29,7 @@ def __compile_tool_regex(self, key):
             log.error(f"Failed to compile regex: {key}")
             raise

-    def _find_entities_matching_id(self, entity_list, entity_name):
+    def _find_entities_matching_id(self, entity_list: dict[str, Entity], entity_name: str):
         default_inherits = self.__get_default_inherits(entity_list)
         if default_inherits:
             matches = [default_inherits]
@@ -49,12 +45,12 @@ def _find_entities_matching_id(self, entity_list, entity_name):
             matches.append(match)
         return matches

-    def __inherit_matching_entities(self, entity_type, entity_name):
-        entity_list = self.entities.get(entity_type)
+    def __inherit_matching_entities(self, entity_type: str, entity_name: str):
+        entity_list = getattr(self.config, entity_type)
         matches = self._find_entities_matching_id(entity_list, entity_name)
         return self.inherit_entities(matches)

-    def __get_default_inherits(self, entity_list):
+    def __get_default_inherits(self, entity_list: dict[str, Entity]):
         if self.default_inherits:
             default_match = entity_list.get(self.default_inherits)
             if default_match:
@@ -103,7 +99,7 @@ def to_galaxy_destination(self, destination):

     def _find_matching_entities(self, tool, user):
         tool_entity = self.inherit_matching_entities("tools", tool.id)
         if not tool_entity:
-            tool_entity = Tool.from_dict(self.loader, {'id': tool.id})
+            tool_entity = Tool(evaluator=self.loader, id=tool.id)

         entity_list = [tool_entity]
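(Reviewer note: after this refactor the overall flow is YAML dict -> pydantic-validated TPVConfig -> EntityToDestinationMapper. A hedged usage sketch using only the public pieces visible in this diff; the config path and printed fields are illustrative, not part of the change itself:

    from tpv.core.loader import TPVConfigLoader
    from tpv.core.mapper import EntityToDestinationMapper

    # parses the YAML and validates it into a pydantic TPVConfig
    loader = TPVConfigLoader.from_url_or_path("config/tpv_config.yml")  # hypothetical path

    # entities now live on a typed config object instead of loose dicts
    print(loader.config.global_config.default_inherits)
    print(list(loader.config.destinations))

    # the mapper reads tools/users/roles/destinations through loader.config
    mapper = EntityToDestinationMapper(loader)

A second config can still be layered on top via loader.merge_config(other_loader.config), which reuses the same inherit/override machinery shown earlier in this diff.)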