diff --git a/gokart_pipeliner/instantiation_task.py b/gokart_pipeliner/instantiation_task.py index 11fb2f9..f773196 100644 --- a/gokart_pipeliner/instantiation_task.py +++ b/gokart_pipeliner/instantiation_task.py @@ -67,5 +67,14 @@ def requires(self): def requires(cls): return {t: getattr(cls, t) for t in task_parameters} + def _get_input_targets(cls, target): + """For task name may not be specified.""" + if target is None: + return cls.input()[list(cls.input().keys())[0]] + if isinstance(target, str): + return cls.input()[target] + return target + task.requires = requires + task._get_input_targets = _get_input_targets return task diff --git a/test/unittest/test_instatiation_task.py b/test/unittest/test_instatiation_task.py index 888741c..cd7be50 100644 --- a/test/unittest/test_instatiation_task.py +++ b/test/unittest/test_instatiation_task.py @@ -90,3 +90,14 @@ def test_override_requires(self): task = InstantiationTask.override_requires(task, ['target']) output = task.requires(task) self.assertDictEqual(output, {'target': task.target}) + + def test_override_requires_get_input_targets(self): + task = MockGokartTargetTask + task.input = lambda: {'target': 'foo'} + task = InstantiationTask.override_requires(task, ['target']) + + output = task._get_input_targets(task, 'target') + self.assertEqual(output, 'foo') + + output = task._get_input_targets(task, None) + self.assertEqual(output, 'foo')