diff --git a/yt/python/yt/cli/yt_binary.py b/yt/python/yt/cli/yt_binary.py index bc942b9d1..552fe280e 100755 --- a/yt/python/yt/cli/yt_binary.py +++ b/yt/python/yt/cli/yt_binary.py @@ -17,7 +17,7 @@ from yt.wrapper.default_config import get_default_config, RemotePatchableValueBase from yt.wrapper.admin_commands import add_switch_leader_parser from yt.wrapper.dirtable_commands import add_dirtable_parsers -from yt.wrapper.flow_commands import get_controller_logs +from yt.wrapper.flow_commands import get_controller_logs, wait_pipeline_state from yt.wrapper.spec_builders import ( MapSpecBuilder, ReduceSpecBuilder, MapReduceSpecBuilder, EraseSpecBuilder, MergeSpecBuilder, SortSpecBuilder, JoinReduceSpecBuilder, RemoteCopySpecBuilder, @@ -2453,22 +2453,41 @@ def add_flow_parser(root_subparsers): add_flow_show_logs_parser(add_flow_subparser) +def wait_pipeline_change(operation, state): + @copy_docstring_from(get_controller_logs) + def wrapper(**kwargs): + sync = kwargs.pop("sync") + + operation(**kwargs) + + if sync: + wait_pipeline_state(target_state=state, **kwargs) + + return wrapper + + def add_flow_start_pipeline_parser(add_parser): - parser = add_parser("start-pipeline", yt.start_pipeline, + parser = add_parser("start-pipeline", wait_pipeline_change(yt.start_pipeline, "working"), help="Start YT Flow pipeline") add_ypath_argument(parser, "pipeline_path", hybrid=True) + parser.add_argument("--sync", action="store_true", + help="Wait for the pipeline to start") def add_flow_stop_pipeline_parser(add_parser): - parser = add_parser("stop-pipeline", yt.stop_pipeline, + parser = add_parser("stop-pipeline", wait_pipeline_change(yt.stop_pipeline, "stopped"), help="Stop YT Flow pipeline") add_ypath_argument(parser, "pipeline_path", hybrid=True) + parser.add_argument("--sync", action="store_true", + help="Wait for the pipeline to stop") def add_flow_pause_pipeline_parser(add_parser): - parser = add_parser("pause-pipeline", yt.pause_pipeline, + parser = add_parser("pause-pipeline", wait_pipeline_change(yt.pause_pipeline, "paused"), help="Pause YT Flow pipeline") add_ypath_argument(parser, "pipeline_path", hybrid=True) + parser.add_argument("--sync", action="store_true", + help="Wait for the pipeline to pause") def add_flow_get_pipeline_spec_parser(add_parser): diff --git a/yt/python/yt/wrapper/flow_commands.py b/yt/python/yt/wrapper/flow_commands.py index 3e5504b08..22c9800be 100644 --- a/yt/python/yt/wrapper/flow_commands.py +++ b/yt/python/yt/wrapper/flow_commands.py @@ -3,6 +3,15 @@ from .dynamic_table_commands import select_rows from .ypath import YPath +from yt.wrapper.common import YtError + +from datetime import datetime, timedelta + +import yt.logger as logger + +import enum +import time + def start_pipeline(pipeline_path, client=None): """Start YT Flow pipeline. @@ -183,6 +192,61 @@ def get_pipeline_state(pipeline_path, client=None): client=client) +class PipelineState(str, enum.Enum): + Unknown = "unknown" + Stopped = "stopped" + Paused = "paused" + Working = "working" + Draining = "draining" + Pausing = "pausing" + Completed = "completed" + + +def wait_pipeline_state(target_state, pipeline_path, client=None, timeout=600): + if target_state == PipelineState.Completed: + target_states = {PipelineState.Completed, } + elif target_state == PipelineState.Working: + target_states = {PipelineState.Completed, PipelineState.Working} + elif target_state == PipelineState.Stopped: + target_states = {PipelineState.Completed, PipelineState.Stopped} + elif target_state == PipelineState.Draining: + target_states = {PipelineState.Completed, PipelineState.Stopped, PipelineState.Draining} + elif target_state == PipelineState.Paused: + target_states = {PipelineState.Completed, PipelineState.Stopped, PipelineState.Paused} + elif target_state == PipelineState.Pausing: + target_states = {PipelineState.Completed, PipelineState.Stopped, PipelineState.Paused, PipelineState.Pausing} + else: + logger.warning("Unknown pipeline state %s", target_state) + return + + invalid_state_transitions = { + PipelineState.Stopped: {PipelineState.Paused, }, + } + + deadline = datetime.now() + timedelta(seconds=timeout) + + while True: + if datetime.now() > deadline: + raise YtError("Wait time out", attributes={"timeout": timeout}) + + current_state = get_pipeline_state(pipeline_path=pipeline_path, client=client) + + if current_state in target_states: + logger.info("Waiting finished (current state: %s, target state: %s)", + current_state, target_state) + return + + if current_state in invalid_state_transitions.get(target_state, []): + raise YtError("Invalid state transition", attributes={ + "current_state": current_state, + "target_state": target_state}) + + logger.info("Still waiting (current state: %s, target state: %s)", + current_state, target_state) + + time.sleep(1) + + def get_flow_view(pipeline_path, view_path=None, format=None, client=None): """Get YT Flow flow view