From a636a34b682c9e4b784ae82983296c15adafd3dd Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 26 Sep 2024 21:42:47 +0100 Subject: [PATCH] feat(framework) Add logstream functionality to `flwr run` (#3613) --- src/py/flwr/cli/log.py | 39 ++++++++++++++++++++++---------------- src/py/flwr/cli/run/run.py | 18 +++++++++++++++++- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/cli/log.py b/src/py/flwr/cli/log.py index cd4079c1c13..7199cefce4f 100644 --- a/src/py/flwr/cli/log.py +++ b/src/py/flwr/cli/log.py @@ -32,6 +32,28 @@ CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) +def start_stream( + run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD +) -> None: + """Start log streaming for a given run ID.""" + try: + while True: + logger(INFO, "Starting logstream for run_id `%s`", run_id) + stream_logs(run_id, channel, refresh_period) + time.sleep(2) + logger(DEBUG, "Reconnecting to logstream") + except KeyboardInterrupt: + logger(INFO, "Exiting logstream") + except grpc.RpcError as e: + # pylint: disable=E1101 + if e.code() == grpc.StatusCode.NOT_FOUND: + logger(ERROR, "Invalid run_id `%s`, exiting", run_id) + if e.code() == grpc.StatusCode.CANCELLED: + pass + finally: + channel.close() + + def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None: """Stream logs from the beginning of a run with connection refresh.""" start_time = time.time() @@ -206,22 +228,7 @@ def _log_with_superexec( channel.subscribe(on_channel_state_change) if stream: - try: - while True: - logger(INFO, "Starting logstream for run_id `%s`", run_id) - stream_logs(run_id, channel, CONN_REFRESH_PERIOD) - time.sleep(2) - logger(DEBUG, "Reconnecting to logstream") - except KeyboardInterrupt: - logger(INFO, "Exiting logstream") - except grpc.RpcError as e: - # pylint: disable=E1101 - if e.code() == grpc.StatusCode.NOT_FOUND: - logger(ERROR, "Invalid run_id `%s`, exiting", run_id) - if e.code() == grpc.StatusCode.CANCELLED: - pass - finally: - channel.close() + start_stream(run_id, channel, CONN_REFRESH_PERIOD) else: logger(INFO, "Printing logstream for run_id `%s`", run_id) print_logs(run_id, channel, timeout=5) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 905055ac70c..2832af3aeba 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -34,6 +34,10 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 from flwr.proto.exec_pb2_grpc import ExecStub +from ..log import start_stream + +CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) + def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" @@ -62,6 +66,14 @@ def run( "inside the `pyproject.toml` in order to be properly overriden.", ), ] = None, + stream: Annotated[ + bool, + typer.Option( + "--stream", + help="Use `--stream` with `flwr run` to display logs;\n " + "logs are not streamed by default.", + ), + ] = False, ) -> None: """Run Flower App.""" typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) @@ -117,7 +129,7 @@ def run( raise typer.Exit(code=1) if "address" in federation_config: - _run_with_superexec(app, federation_config, config_overrides) + _run_with_superexec(app, federation_config, config_overrides, stream) else: _run_without_superexec(app, federation_config, config_overrides, federation) @@ -126,6 +138,7 @@ def _run_with_superexec( app: Path, federation_config: dict[str, Any], config_overrides: Optional[list[str]], + stream: bool, ) -> None: insecure_str = federation_config.get("insecure") @@ -183,6 +196,9 @@ def _run_with_superexec( fab_path.unlink() typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) + if stream: + start_stream(res.run_id, channel, CONN_REFRESH_PERIOD) + def _run_without_superexec( app: Optional[Path],