Skip to content

Commit

Permalink
feat(framework) Add logstream functionality to flwr run (#3613)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Sep 26, 2024
1 parent 83cd4ba commit a636a34
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
39 changes: 23 additions & 16 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)


def start_stream(
run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
) -> None:
"""Start log streaming for a given run ID."""
try:
while True:
logger(INFO, "Starting logstream for run_id `%s`", run_id)
stream_logs(run_id, channel, refresh_period)
time.sleep(2)
logger(DEBUG, "Reconnecting to logstream")
except KeyboardInterrupt:
logger(INFO, "Exiting logstream")
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() == grpc.StatusCode.NOT_FOUND:
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
if e.code() == grpc.StatusCode.CANCELLED:
pass
finally:
channel.close()


def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None:
"""Stream logs from the beginning of a run with connection refresh."""
start_time = time.time()
Expand Down Expand Up @@ -206,22 +228,7 @@ def _log_with_superexec(
channel.subscribe(on_channel_state_change)

if stream:
try:
while True:
logger(INFO, "Starting logstream for run_id `%s`", run_id)
stream_logs(run_id, channel, CONN_REFRESH_PERIOD)
time.sleep(2)
logger(DEBUG, "Reconnecting to logstream")
except KeyboardInterrupt:
logger(INFO, "Exiting logstream")
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() == grpc.StatusCode.NOT_FOUND:
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
if e.code() == grpc.StatusCode.CANCELLED:
pass
finally:
channel.close()
start_stream(run_id, channel, CONN_REFRESH_PERIOD)
else:
logger(INFO, "Printing logstream for run_id `%s`", run_id)
print_logs(run_id, channel, timeout=5)
18 changes: 17 additions & 1 deletion src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

from ..log import start_stream

CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
Expand Down Expand Up @@ -62,6 +66,14 @@ def run(
"inside the `pyproject.toml` in order to be properly overriden.",
),
] = None,
stream: Annotated[
bool,
typer.Option(
"--stream",
help="Use `--stream` with `flwr run` to display logs;\n "
"logs are not streamed by default.",
),
] = False,
) -> None:
"""Run Flower App."""
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
Expand Down Expand Up @@ -117,7 +129,7 @@ def run(
raise typer.Exit(code=1)

if "address" in federation_config:
_run_with_superexec(app, federation_config, config_overrides)
_run_with_superexec(app, federation_config, config_overrides, stream)
else:
_run_without_superexec(app, federation_config, config_overrides, federation)

Expand All @@ -126,6 +138,7 @@ def _run_with_superexec(
app: Path,
federation_config: dict[str, Any],
config_overrides: Optional[list[str]],
stream: bool,
) -> None:

insecure_str = federation_config.get("insecure")
Expand Down Expand Up @@ -183,6 +196,9 @@ def _run_with_superexec(
fab_path.unlink()
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)

if stream:
start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)


def _run_without_superexec(
app: Optional[Path],
Expand Down

0 comments on commit a636a34

Please sign in to comment.