Skip to content
This repository has been archived by the owner on May 11, 2021. It is now read-only.

Commit

Permalink
Merge pull request #22 from vaaaaanquish/specified_task_name
Browse files Browse the repository at this point in the history
override _get_input_targets
  • Loading branch information
vaaaaanquish authored Jan 12, 2021
2 parents e5fd13f + e6de87f commit ae7a44a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
9 changes: 9 additions & 0 deletions gokart_pipeliner/instantiation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,14 @@ def requires(self):
def requires(cls):
return {t: getattr(cls, t) for t in task_parameters}

def _get_input_targets(cls, target):
"""For task name may not be specified."""
if target is None:
return cls.input()[list(cls.input().keys())[0]]
if isinstance(target, str):
return cls.input()[target]
return target

task.requires = requires
task._get_input_targets = _get_input_targets
return task
11 changes: 11 additions & 0 deletions test/unittest/test_instatiation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,14 @@ def test_override_requires(self):
task = InstantiationTask.override_requires(task, ['target'])
output = task.requires(task)
self.assertDictEqual(output, {'target': task.target})

def test_override_requires_get_input_targets(self):
task = MockGokartTargetTask
task.input = lambda: {'target': 'foo'}
task = InstantiationTask.override_requires(task, ['target'])

output = task._get_input_targets(task, 'target')
self.assertEqual(output, 'foo')

output = task._get_input_targets(task, None)
self.assertEqual(output, 'foo')

0 comments on commit ae7a44a

Please sign in to comment.