Skip to content

Commit

Permalink
Add format command and reformat test code
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwang committed Aug 25, 2024
1 parent 9759507 commit af57588
Show file tree
Hide file tree
Showing 17 changed files with 1,949 additions and 756 deletions.
47 changes: 30 additions & 17 deletions tests/test_entity.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,72 @@
import os
import unittest
from tpv.rules import gateway
from tpv.core.entities import Destination
from tpv.core.entities import Tool
from tpv.core.loader import TPVConfigLoader

from tpv.commands.test import mock_galaxy
from tpv.core.entities import Destination, Tool
from tpv.core.loader import TPVConfigLoader
from tpv.rules import gateway


class TestEntity(unittest.TestCase):

@staticmethod
def _map_to_destination(app, job, tool, user):
tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
tpv_config = os.path.join(
os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml"
)
gateway.ACTIVE_DESTINATION_MAPPER = None
return gateway.map_tool_to_destination(app, job, tool, user, tpv_config_files=[tpv_config])
return gateway.map_tool_to_destination(
app, job, tool, user, tpv_config_files=[tpv_config]
)

# issue: https://github.com/galaxyproject/total-perspective-vortex/issues/53
def test_all_entities_refer_to_same_loader(self):
app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
app = mock_galaxy.App(
job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml")
)
job = mock_galaxy.Job()

tool = mock_galaxy.Tool('bwa')
user = mock_galaxy.User('ford', '[email protected]')
tool = mock_galaxy.Tool("bwa")
user = mock_galaxy.User("ford", "[email protected]")

# just map something so the ACTIVE_DESTINATION_MAPPER is populated
self._map_to_destination(app, job, tool, user)

# get the original loader
original_loader = gateway.ACTIVE_DESTINATION_MAPPER.loader

context = {
'app': app,
'job': job
}
context = {"app": app, "job": job}
# make sure we are still referring to the same loader after evaluation
evaluated_entity = gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities(context, tool, user)
evaluated_entity = (
gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities(
context, tool, user
)
)
assert evaluated_entity.evaluator == original_loader
for rule in evaluated_entity.rules:
assert rule.evaluator == original_loader

def test_destination_to_dict(self):
tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
tpv_config = os.path.join(
os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml"
)
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a destination
destination = loader.config.destinations["k8s_environment"]
# serialize the destination
serialized_destination = destination.dict()
# deserialize the same destination
deserialized_destination = Destination(evaluator=loader, **serialized_destination)
deserialized_destination = Destination(
evaluator=loader, **serialized_destination
)
# make sure the deserialized destination is the same as the original
self.assertEqual(deserialized_destination, destination)

def test_tool_to_dict(self):
tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml')
tpv_config = os.path.join(
os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml"
)
loader = TPVConfigLoader.from_url_or_path(tpv_config)

# create a tool
Expand Down
11 changes: 8 additions & 3 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
"""Unit tests module for the helper functions"""

import unittest

from tpv.commands.test import mock_galaxy
from tpv.core.helpers import get_dataset_attributes


class TestHelpers(unittest.TestCase):
"""Tests for helper functions"""

def test_get_dataset_attributes(self):
"""Test that the function returns a dictionary with the correct attributes"""
job = mock_galaxy.Job()
job.add_input_dataset(
mock_galaxy.DatasetAssociation(
"test",
mock_galaxy.Dataset("test.txt", file_size=7*1024**3, object_store_id="files1")
)
mock_galaxy.Dataset(
"test.txt", file_size=7 * 1024**3, object_store_id="files1"
),
)
)
dataset_attributes = get_dataset_attributes(job.input_datasets)
expected_result = {0: {'object_store_id': 'files1', 'size': 7*1024**3}}
expected_result = {0: {"object_store_id": "files1", "size": 7 * 1024**3}}
self.assertEqual(dataset_attributes, expected_result)
59 changes: 38 additions & 21 deletions tests/test_mapper_basic.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,83 @@
import os
import re
import unittest
from tpv.rules import gateway
from tpv.commands.test import mock_galaxy

from galaxy.jobs.mapper import JobMappingException

from tpv.commands.test import mock_galaxy
from tpv.rules import gateway


class TestMapperBasic(unittest.TestCase):

@staticmethod
def _map_to_destination(tool, tpv_config_path=None):
galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
galaxy_app = mock_galaxy.App(
job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml")
)
job = mock_galaxy.Job()
user = mock_galaxy.User('gargravarr', '[email protected]')
tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__),
'fixtures/mapping-basic.yml')
user = mock_galaxy.User("gargravarr", "[email protected]")
tpv_config = tpv_config_path or os.path.join(
os.path.dirname(__file__), "fixtures/mapping-basic.yml"
)
gateway.ACTIVE_DESTINATION_MAPPER = None
return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])
return gateway.map_tool_to_destination(
galaxy_app, job, tool, user, tpv_config_files=[tpv_config]
)

def test_map_default_tool(self):
tool = mock_galaxy.Tool('sometool')
tool = mock_galaxy.Tool("sometool")
destination = self._map_to_destination(tool)
self.assertEqual(destination.id, "local")

def test_map_overridden_tool(self):
tool = mock_galaxy.Tool('bwa')
tool = mock_galaxy.Tool("bwa")
destination = self._map_to_destination(tool)
self.assertEqual(destination.id, "k8s_environment")

def test_map_unschedulable_tool(self):
tool = mock_galaxy.Tool('unschedulable_tool')
with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"):
tool = mock_galaxy.Tool("unschedulable_tool")
with self.assertRaisesRegex(
JobMappingException, "No destinations are available to fulfill request"
):
self._map_to_destination(tool)

def test_map_invalidly_tagged_tool(self):
tool = mock_galaxy.Tool('invalidly_tagged_tool')
config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-tags.yml')
with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"):
tool = mock_galaxy.Tool("invalidly_tagged_tool")
config = os.path.join(
os.path.dirname(__file__), "fixtures/mapping-invalid-tags.yml"
)
with self.assertRaisesRegex(
Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"
):
self._map_to_destination(tool, tpv_config_path=config)

def test_map_tool_by_regex(self):
tool = mock_galaxy.Tool('regex_tool_test')
tool = mock_galaxy.Tool("regex_tool_test")
destination = self._map_to_destination(tool)
self.assertEqual(destination.id, "k8s_environment")

def test_map_tool_by_regex_mismatch(self):
tool = mock_galaxy.Tool('regex_t_test')
tool = mock_galaxy.Tool("regex_t_test")
destination = self._map_to_destination(tool)
self.assertEqual(destination.id, "local")

def test_map_tool_with_invalid_regex(self):
tool = mock_galaxy.Tool('sometool')
config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-regex.yml')
tool = mock_galaxy.Tool("sometool")
config = os.path.join(
os.path.dirname(__file__), "fixtures/mapping-invalid-regex.yml"
)
with self.assertRaisesRegex(re.error, "bad escape"):
self._map_to_destination(tool, tpv_config_path=config)

def test_map_abstract_tool_should_fail(self):
tool = mock_galaxy.Tool('my_abstract_tool')
with self.assertRaisesRegex(JobMappingException, "This entity is abstract and cannot be mapped"):
tool = mock_galaxy.Tool("my_abstract_tool")
with self.assertRaisesRegex(
JobMappingException, "This entity is abstract and cannot be mapped"
):
self._map_to_destination(tool)

def test_map_concrete_descendant_should_succeed(self):
tool = mock_galaxy.Tool('my_concrete_tool')
tool = mock_galaxy.Tool("my_concrete_tool")
destination = self._map_to_destination(tool)
self.assertEqual(destination.id, "local")
90 changes: 67 additions & 23 deletions tests/test_mapper_context.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,100 @@
import os
import unittest
from tpv.rules import gateway

from tpv.commands.test import mock_galaxy
from tpv.rules import gateway


class TestMapperContext(unittest.TestCase):

@staticmethod
def _map_to_destination(tool, user, datasets, tpv_config_path=None):
galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
galaxy_app = mock_galaxy.App(
job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml")
)
job = mock_galaxy.Job()
for d in datasets:
job.add_input_dataset(d)
tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__),
'fixtures/mapping-context.yml')
tpv_config = tpv_config_path or os.path.join(
os.path.dirname(__file__), "fixtures/mapping-context.yml"
)
gateway.ACTIVE_DESTINATION_MAPPER = None
return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])
return gateway.map_tool_to_destination(
galaxy_app, job, tool, user, tpv_config_files=[tpv_config]
)

def test_map_context_default_overrides_global(self):
tool = mock_galaxy.Tool('trinity')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]
tool = mock_galaxy.Tool("trinity")
user = mock_galaxy.User("gargravarr", "[email protected]")
datasets = [
mock_galaxy.DatasetAssociation(
"test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3)
)
]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.id, "local")
self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['3'])
self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 3')
self.assertEqual(
[
env["value"]
for env in destination.env
if env["name"] == "TEST_JOB_SLOTS"
],
["3"],
)
self.assertEqual(
destination.params["native_spec"], "--mem 9 --cores 3 --gpus 3"
)

def test_map_tool_overrides_default(self):
tool = mock_galaxy.Tool('bwa')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]
tool = mock_galaxy.Tool("bwa")
user = mock_galaxy.User("gargravarr", "[email protected]")
datasets = [
mock_galaxy.DatasetAssociation(
"test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3)
)
]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.id, "k8s_environment")
self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['5'])
self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 4')
self.assertEqual(
[
env["value"]
for env in destination.env
if env["name"] == "TEST_JOB_SLOTS"
],
["5"],
)
self.assertEqual(
destination.params["native_spec"], "--mem 15 --cores 5 --gpus 4"
)

def test_context_variable_overridden_in_rule(self):
# test that job will not fail with 40GB input size because large_input_size has been set to 60
tool = mock_galaxy.Tool('bwa')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=40*1024**3))]
tool = mock_galaxy.Tool("bwa")
user = mock_galaxy.User("gargravarr", "[email protected]")
datasets = [
mock_galaxy.DatasetAssociation(
"test", mock_galaxy.Dataset("test.txt", file_size=40 * 1024**3)
)
]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 2')
self.assertEqual(
destination.params["native_spec"], "--mem 15 --cores 5 --gpus 2"
)

def test_context_variable_defined_for_tool_in_rule(self):
# test that context variable set for tool entity but not set in ancestor entities is defined
tool = mock_galaxy.Tool('canu')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=3*1024**3))]
tool = mock_galaxy.Tool("canu")
user = mock_galaxy.User("gargravarr", "[email protected]")
datasets = [
mock_galaxy.DatasetAssociation(
"test", mock_galaxy.Dataset("test.txt", file_size=3 * 1024**3)
)
]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 1')
self.assertEqual(
destination.params["native_spec"], "--mem 9 --cores 3 --gpus 1"
)
Loading

0 comments on commit af57588

Please sign in to comment.