Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to to_dict functions to Destination and TagSetManager classes #119

Merged
merged 6 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/test_entity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import unittest
from tpv.rules import gateway
from tpv.core.entities import Destination
from tpv.core.entities import Tag
from tpv.core.entities import TagType
from tpv.core.entities import Tool
from tpv.core.loader import TPVConfigLoader
from tpv.commands.test import mock_galaxy


Expand Down Expand Up @@ -38,6 +41,33 @@ def test_all_entities_refer_to_same_loader(self):
for rule in evaluated_entity.rules:
assert rule.loader == original_loader

def test_destination_to_dict(self):

tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a destination
destination = loader.destinations["k8s_environment"]
# serialize the destination
serialized_destination = destination.to_dict()
# deserialize the same destination
deserialized_destination = Destination.from_dict(loader, serialized_destination)
# make sure the deserialized destination is the same as the original
self.assertEqual(deserialized_destination, destination)

def test_tool_to_dict(self):
tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a tool
tool = loader.tools["limbo"]
# serialize the tool
serialized_destination = tool.to_dict()
# deserialize the same tool
deserialized_destination = Tool.from_dict(loader, serialized_destination)
# make sure the deserialized tool is the same as the original
self.assertEqual(deserialized_destination, tool)

def test_tag_equivalence(self):
tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE)
tag2 = Tag("tag_name2", "tag_value", TagType.REQUIRE)
Expand Down
112 changes: 112 additions & 0 deletions tpv/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def score(self, other: TagSetManager) -> bool:
# penalize tags that don't exist in the other
- sum(int(tag.tag_type) for tag in self.tags if not other.contains_tag(tag)))

def __eq__(self, other):
if not isinstance(other, TagSetManager):
# don't attempt to compare against unrelated types
return NotImplemented

return self.tags == other.tags

def __repr__(self):
return f"{self.__class__} tags={[tag for tag in self.tags]}"

Expand All @@ -170,6 +177,15 @@ def from_dict(tags: list[dict]) -> TagSetManager:
tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REJECT))
return TagSetManager(tags=tag_list)

def to_dict(self) -> dict:
result_dict = {
'require': [tag.value for tag in self.tags if tag.tag_type == TagType.REQUIRE],
'prefer': [tag.value for tag in self.tags if tag.tag_type == TagType.PREFER],
'accept': [tag.value for tag in self.tags if tag.tag_type == TagType.ACCEPT],
'reject': [tag.value for tag in self.tags if tag.tag_type == TagType.REJECT]
}
return result_dict


class Entity(object):

Expand Down Expand Up @@ -280,6 +296,31 @@ def __repr__(self):
f"tags={self.tpv_tags}, rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, "\
f"context={self.context}"

def __eq__(self, other):
if not isinstance(other, self.__class__):
# don't attempt to compare against unrelated types
return NotImplemented

return (
self.id == other.id and
self.abstract == other.abstract and
self.cores == other.cores and
self.mem == other.mem and
self.gpus == other.gpus and
self.min_cores == other.min_cores and
self.min_mem == other.min_mem and
self.min_gpus == other.min_gpus and
self.max_cores == other.max_cores and
self.max_mem == other.max_mem and
self.max_gpus == other.max_gpus and
self.env == other.env and
self.params == other.params and
self.resubmit == other.resubmit and
self.tpv_tags == other.tpv_tags and
self.inherits == other.inherits and
self.context == other.context
)

def merge_env_list(self, original, replace):
for i, original_elem in enumerate(original):
for j, replace_elem in enumerate(replace):
Expand Down Expand Up @@ -422,6 +463,28 @@ def rank_destinations(self, destinations, context):
log.debug(f"Ranking destinations: {destinations} for entity: {self} using default ranker")
return sorted(destinations, key=lambda d: d.score(self), reverse=True)

def to_dict(self):
dict_obj = {
'id': self.id,
'abstract': self.abstract,
'cores': self.cores,
'mem': self.mem,
'gpus': self.gpus,
'min_cores': self.min_cores,
'min_mem': self.min_mem,
'min_gpus': self.min_gpus,
'max_cores': self.max_cores,
'max_mem': self.max_mem,
'max_gpus': self.max_gpus,
'env': self.env,
'params': self.params,
'resubmit': self.resubmit,
'scheduling': self.tpv_tags.to_dict(),
'inherits': self.inherits,
'context': self.context
}
return dict_obj


class EntityWithRules(Entity):

Expand Down Expand Up @@ -472,6 +535,11 @@ def from_dict(cls: type, loader, entity_dict):
rules=entity_dict.get('rules')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['rules'] = [rule.to_dict() for rule in self.rules.values()]
return dict_obj

def override(self, entity):
new_entity = super().override(entity)
new_entity.rules = copy.deepcopy(entity.rules)
Expand Down Expand Up @@ -504,6 +572,11 @@ def evaluate(self, context):
def __repr__(self):
return super().__repr__() + f", rules={self.rules}"

def __eq__(self, other):
return super().__eq__(other) and (
self.rules == other.rules
)


class Tool(EntityWithRules):

Expand Down Expand Up @@ -578,6 +651,38 @@ def from_dict(loader, entity_dict):
handler_tags=entity_dict.get('tags')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['runner'] = self.runner
dict_obj['destination_name_override'] = self.dest_name
dict_obj['min_accepted_cores'] = self.min_accepted_cores
dict_obj['min_accepted_mem'] = self.min_accepted_mem
dict_obj['min_accepted_gpus'] = self.min_accepted_gpus
dict_obj['max_accepted_cores'] = self.max_accepted_cores
dict_obj['max_accepted_mem'] = self.max_accepted_mem
dict_obj['max_accepted_gpus'] = self.max_accepted_gpus
dict_obj['scheduling'] = self.tpv_dest_tags.to_dict()
dict_obj['tags'] = self.handler_tags
return dict_obj

def __eq__(self, other):
if not isinstance(other, Destination):
# don't attempt to compare against unrelated types
return NotImplemented

return super().__eq__(other) and (
self.runner == other.runner and
self.dest_name == other.dest_name and
self.min_accepted_cores == other.min_accepted_cores and
self.min_accepted_mem == other.min_accepted_mem and
self.min_accepted_gpus == other.min_accepted_gpus and
self.max_accepted_cores == other.max_accepted_cores and
self.max_accepted_mem == other.max_accepted_mem and
self.max_accepted_gpus == other.max_accepted_gpus and
self.tpv_dest_tags == other.tpv_dest_tags and
self.handler_tags == other.handler_tags
)

def __repr__(self):
return f"runner={self.runner}, dest_name={self.dest_name}, min_accepted_cores={self.min_accepted_cores}, "\
f"min_accepted_mem={self.min_accepted_mem}, min_accepted_gpus={self.min_accepted_gpus}, "\
Expand Down Expand Up @@ -726,6 +831,13 @@ def from_dict(loader, entity_dict):
fail=entity_dict.get('fail')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['if'] = self.match
dict_obj['execute'] = self.execute
dict_obj['fail'] = self.fail
return dict_obj

def override(self, entity):
new_entity = super().override(entity)
new_entity.match = self.match if self.match is not None else getattr(entity, 'match', None)
Expand Down
1 change: 0 additions & 1 deletion tpv/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def validate_entities(self, entity_class: type, entity_list: dict) -> dict:
if not entity_dict:
entity_dict = {}
entity_dict['id'] = entity_id
entity_class.from_dict(self, entity_dict)
validated[entity_id] = entity_class.from_dict(self, entity_dict)
except Exception:
log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}")
Expand Down