Skip to content

Commit

Permalink
Merge pull request #119 from pauldg/dest_to_dict_patch
Browse files Browse the repository at this point in the history
Add to to_dict functions to Destination and TagSetManager classes
  • Loading branch information
nuwang authored Jan 17, 2024
2 parents b8dd7d1 + 1e400f6 commit dc2a19a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 1 deletion.
30 changes: 30 additions & 0 deletions tests/test_entity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import unittest
from tpv.rules import gateway
from tpv.core.entities import Destination
from tpv.core.entities import Tag
from tpv.core.entities import TagType
from tpv.core.entities import Tool
from tpv.core.loader import TPVConfigLoader
from tpv.commands.test import mock_galaxy


Expand Down Expand Up @@ -38,6 +41,33 @@ def test_all_entities_refer_to_same_loader(self):
for rule in evaluated_entity.rules:
assert rule.loader == original_loader

def test_destination_to_dict(self):

tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a destination
destination = loader.destinations["k8s_environment"]
# serialize the destination
serialized_destination = destination.to_dict()
# deserialize the same destination
deserialized_destination = Destination.from_dict(loader, serialized_destination)
# make sure the deserialized destination is the same as the original
self.assertEqual(deserialized_destination, destination)

def test_tool_to_dict(self):
tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a tool
tool = loader.tools["limbo"]
# serialize the tool
serialized_destination = tool.to_dict()
# deserialize the same tool
deserialized_destination = Tool.from_dict(loader, serialized_destination)
# make sure the deserialized tool is the same as the original
self.assertEqual(deserialized_destination, tool)

def test_tag_equivalence(self):
tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE)
tag2 = Tag("tag_name2", "tag_value", TagType.REQUIRE)
Expand Down
112 changes: 112 additions & 0 deletions tpv/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def score(self, other: TagSetManager) -> bool:
# penalize tags that don't exist in the other
- sum(int(tag.tag_type) for tag in self.tags if not other.contains_tag(tag)))

def __eq__(self, other):
if not isinstance(other, TagSetManager):
# don't attempt to compare against unrelated types
return NotImplemented

return self.tags == other.tags

def __repr__(self):
return f"{self.__class__} tags={[tag for tag in self.tags]}"

Expand All @@ -170,6 +177,15 @@ def from_dict(tags: list[dict]) -> TagSetManager:
tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REJECT))
return TagSetManager(tags=tag_list)

def to_dict(self) -> dict:
result_dict = {
'require': [tag.value for tag in self.tags if tag.tag_type == TagType.REQUIRE],
'prefer': [tag.value for tag in self.tags if tag.tag_type == TagType.PREFER],
'accept': [tag.value for tag in self.tags if tag.tag_type == TagType.ACCEPT],
'reject': [tag.value for tag in self.tags if tag.tag_type == TagType.REJECT]
}
return result_dict


class Entity(object):

Expand Down Expand Up @@ -280,6 +296,31 @@ def __repr__(self):
f"tags={self.tpv_tags}, rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, "\
f"context={self.context}"

def __eq__(self, other):
if not isinstance(other, self.__class__):
# don't attempt to compare against unrelated types
return NotImplemented

return (
self.id == other.id and
self.abstract == other.abstract and
self.cores == other.cores and
self.mem == other.mem and
self.gpus == other.gpus and
self.min_cores == other.min_cores and
self.min_mem == other.min_mem and
self.min_gpus == other.min_gpus and
self.max_cores == other.max_cores and
self.max_mem == other.max_mem and
self.max_gpus == other.max_gpus and
self.env == other.env and
self.params == other.params and
self.resubmit == other.resubmit and
self.tpv_tags == other.tpv_tags and
self.inherits == other.inherits and
self.context == other.context
)

def merge_env_list(self, original, replace):
for i, original_elem in enumerate(original):
for j, replace_elem in enumerate(replace):
Expand Down Expand Up @@ -422,6 +463,28 @@ def rank_destinations(self, destinations, context):
log.debug(f"Ranking destinations: {destinations} for entity: {self} using default ranker")
return sorted(destinations, key=lambda d: d.score(self), reverse=True)

def to_dict(self):
dict_obj = {
'id': self.id,
'abstract': self.abstract,
'cores': self.cores,
'mem': self.mem,
'gpus': self.gpus,
'min_cores': self.min_cores,
'min_mem': self.min_mem,
'min_gpus': self.min_gpus,
'max_cores': self.max_cores,
'max_mem': self.max_mem,
'max_gpus': self.max_gpus,
'env': self.env,
'params': self.params,
'resubmit': self.resubmit,
'scheduling': self.tpv_tags.to_dict(),
'inherits': self.inherits,
'context': self.context
}
return dict_obj


class EntityWithRules(Entity):

Expand Down Expand Up @@ -472,6 +535,11 @@ def from_dict(cls: type, loader, entity_dict):
rules=entity_dict.get('rules')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['rules'] = [rule.to_dict() for rule in self.rules.values()]
return dict_obj

def override(self, entity):
new_entity = super().override(entity)
new_entity.rules = copy.deepcopy(entity.rules)
Expand Down Expand Up @@ -504,6 +572,11 @@ def evaluate(self, context):
def __repr__(self):
return super().__repr__() + f", rules={self.rules}"

def __eq__(self, other):
return super().__eq__(other) and (
self.rules == other.rules
)


class Tool(EntityWithRules):

Expand Down Expand Up @@ -578,6 +651,38 @@ def from_dict(loader, entity_dict):
handler_tags=entity_dict.get('tags')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['runner'] = self.runner
dict_obj['destination_name_override'] = self.dest_name
dict_obj['min_accepted_cores'] = self.min_accepted_cores
dict_obj['min_accepted_mem'] = self.min_accepted_mem
dict_obj['min_accepted_gpus'] = self.min_accepted_gpus
dict_obj['max_accepted_cores'] = self.max_accepted_cores
dict_obj['max_accepted_mem'] = self.max_accepted_mem
dict_obj['max_accepted_gpus'] = self.max_accepted_gpus
dict_obj['scheduling'] = self.tpv_dest_tags.to_dict()
dict_obj['tags'] = self.handler_tags
return dict_obj

def __eq__(self, other):
if not isinstance(other, Destination):
# don't attempt to compare against unrelated types
return NotImplemented

return super().__eq__(other) and (
self.runner == other.runner and
self.dest_name == other.dest_name and
self.min_accepted_cores == other.min_accepted_cores and
self.min_accepted_mem == other.min_accepted_mem and
self.min_accepted_gpus == other.min_accepted_gpus and
self.max_accepted_cores == other.max_accepted_cores and
self.max_accepted_mem == other.max_accepted_mem and
self.max_accepted_gpus == other.max_accepted_gpus and
self.tpv_dest_tags == other.tpv_dest_tags and
self.handler_tags == other.handler_tags
)

def __repr__(self):
return f"runner={self.runner}, dest_name={self.dest_name}, min_accepted_cores={self.min_accepted_cores}, "\
f"min_accepted_mem={self.min_accepted_mem}, min_accepted_gpus={self.min_accepted_gpus}, "\
Expand Down Expand Up @@ -726,6 +831,13 @@ def from_dict(loader, entity_dict):
fail=entity_dict.get('fail')
)

def to_dict(self):
dict_obj = super().to_dict()
dict_obj['if'] = self.match
dict_obj['execute'] = self.execute
dict_obj['fail'] = self.fail
return dict_obj

def override(self, entity):
new_entity = super().override(entity)
new_entity.match = self.match if self.match is not None else getattr(entity, 'match', None)
Expand Down
1 change: 0 additions & 1 deletion tpv/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def validate_entities(self, entity_class: type, entity_list: dict) -> dict:
if not entity_dict:
entity_dict = {}
entity_dict['id'] = entity_id
entity_class.from_dict(self, entity_dict)
validated[entity_id] = entity_class.from_dict(self, entity_dict)
except Exception:
log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}")
Expand Down

0 comments on commit dc2a19a

Please sign in to comment.