Skip to content

Commit

Permalink
feat(framework) Add prompt for flwr login when authentication fails (
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Dec 17, 2024
1 parent e9c3653 commit 6998874
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 13 deletions.
14 changes: 8 additions & 6 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

from .utils import init_channel, try_obtain_cli_auth_plugin
from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler


def start_stream(
Expand Down Expand Up @@ -88,8 +88,9 @@ def stream_logs(
latest_timestamp = 0.0
res = None
try:
for res in stub.StreamLogs(req, timeout=duration):
print(res.log_output, end="")
with unauthenticated_exc_handler():
for res in stub.StreamLogs(req, timeout=duration):
print(res.log_output, end="")
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() != grpc.StatusCode.DEADLINE_EXCEEDED:
Expand All @@ -109,9 +110,10 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
try:
while True:
try:
# Enforce timeout for graceful exit
for res in stub.StreamLogs(req, timeout=timeout):
print(res.log_output)
with unauthenticated_exc_handler():
# Enforce timeout for graceful exit
for res in stub.StreamLogs(req, timeout=timeout):
print(res.log_output)
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/cli/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from flwr.proto.exec_pb2_grpc import ExecStub

from .utils import init_channel, try_obtain_cli_auth_plugin
from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler

_RunListType = tuple[int, str, str, str, str, str, str, str, str]

Expand Down Expand Up @@ -295,7 +295,8 @@ def _list_runs(
output_format: str = CliOutputFormat.DEFAULT,
) -> None:
"""List all runs."""
res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
with unauthenticated_exc_handler():
res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
run_dict = {run_id: run_from_proto(proto) for run_id, proto in res.run_dict.items()}

formatted_runs = _format_runs(run_dict, res.now)
Expand All @@ -311,7 +312,8 @@ def _display_one_run(
output_format: str = CliOutputFormat.DEFAULT,
) -> None:
"""Display information about a specific run."""
res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
with unauthenticated_exc_handler():
res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
if not res.run_dict:
raise ValueError(f"Run ID {run_id} not found")

Expand Down
9 changes: 7 additions & 2 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
from flwr.proto.exec_pb2_grpc import ExecStub

from ..log import start_stream
from ..utils import init_channel, try_obtain_cli_auth_plugin
from ..utils import (
init_channel,
try_obtain_cli_auth_plugin,
unauthenticated_exc_handler,
)

CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)

Expand Down Expand Up @@ -166,7 +170,8 @@ def _run_with_exec_api(
override_config=user_config_to_proto(parse_config_args(config_overrides)),
federation_options=configs_record_to_proto(c_record),
)
res = stub.StartRun(req)
with unauthenticated_exc_handler():
res = stub.StartRun(req)

if res.HasField("run_id"):
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/stop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

from .utils import init_channel, try_obtain_cli_auth_plugin
from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler


def stop( # pylint: disable=R0914
Expand Down Expand Up @@ -113,7 +113,8 @@ def stop( # pylint: disable=R0914

def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
"""Stop a run."""
response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
with unauthenticated_exc_handler():
response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
if response.success:
typer.secho(f"✅ Run {run_id} successfully stopped.", fg=typer.colors.GREEN)
if output_format == CliOutputFormat.JSON:
Expand Down
23 changes: 23 additions & 0 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import hashlib
import json
import re
from collections.abc import Iterator
from contextlib import contextmanager
from logging import DEBUG
from pathlib import Path
from typing import Any, Callable, Optional, cast
Expand Down Expand Up @@ -231,3 +233,24 @@ def on_channel_state_change(channel_connectivity: str) -> None:
)
channel.subscribe(on_channel_state_change)
return channel


@contextmanager
def unauthenticated_exc_handler() -> Iterator[None]:
"""Context manager to handle gRPC UNAUTHENTICATED errors.
It catches grpc.RpcError exceptions with UNAUTHENTICATED status, informs the user,
and exits the application. All other exceptions will be allowed to escape.
"""
try:
yield
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.UNAUTHENTICATED:
raise
typer.secho(
"❌ Authentication failed. Please run `flwr login`"
" to authenticate and try again.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1) from None

0 comments on commit 6998874

Please sign in to comment.