Skip to content

Commit

Permalink
feat(framework) Add user auth CLI integration (#4602)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <[email protected]>
Co-authored-by: Chong Shen Ng <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
4 people authored Dec 11, 2024
1 parent 1daee43 commit 59809ec
Show file tree
Hide file tree
Showing 13 changed files with 345 additions and 104 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pyyaml = "^6.0.2"
requests = "^2.31.0"
# Optional dependencies (Simulation Engine)
ray = { version = "==2.10.0", optional = true, python = ">=3.9,<3.12" }
# Optional dependencies (REST transport layer)
starlette = { version = "^0.31.0", optional = true }
uvicorn = { version = "^0.23.0", extras = ["standard"], optional = true }

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .build import build
from .install import install
from .log import log
from .login import login
from .ls import ls
from .new import new
from .run import run
Expand All @@ -41,6 +42,7 @@
app.command()(log)
app.command()(ls)
app.command()(stop)
app.command()(login)

typer_click_object = get_command(app)

Expand Down
86 changes: 86 additions & 0 deletions src/py/flwr/cli/cli_user_auth_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower run interceptor."""


from typing import Any, Callable, Union

import grpc

from flwr.common.auth_plugin import CliAuthPlugin
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
StartRunRequest,
StreamLogsRequest,
)

Request = Union[
StartRunRequest,
StreamLogsRequest,
]


class CliUserAuthInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor # type: ignore
):
"""CLI interceptor for user authentication."""

def __init__(self, auth_plugin: CliAuthPlugin):
self.auth_plugin = auth_plugin

def _authenticated_call(
self,
continuation: Callable[[Any, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Request,
) -> grpc.Call:
"""Send and receive tokens via metadata."""
new_metadata = self.auth_plugin.write_tokens_to_metadata(
client_call_details.metadata or []
)

details = client_call_details._replace(metadata=new_metadata)

response = continuation(details, request)
if response.initial_metadata():
retrieved_metadata = dict(response.initial_metadata())
self.auth_plugin.store_tokens(retrieved_metadata)

return response

def intercept_unary_unary(
self,
continuation: Callable[[Any, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Request,
) -> grpc.Call:
"""Intercept a unary-unary call for user authentication.
This method intercepts a unary-unary RPC call initiated from the CLI and adds
the required authentication tokens to the RPC metadata.
"""
return self._authenticated_call(continuation, client_call_details, request)

def intercept_unary_stream(
self,
continuation: Callable[[Any, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Request,
) -> grpc.Call:
"""Intercept a unary-stream call for user authentication.
This method intercepts a unary-stream RPC call initiated from the CLI and adds
the required authentication tokens to the RPC metadata.
"""
return self._authenticated_call(continuation, client_call_details, request)
26 changes: 6 additions & 20 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@

from flwr.cli.config_utils import (
load_and_validate,
validate_certificate_in_federation_config,
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common.constant import CONN_RECONNECT_INTERVAL, CONN_REFRESH_PERIOD
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log as logger
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

from .utils import init_channel, try_obtain_cli_auth_plugin


def start_stream(
run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
Expand Down Expand Up @@ -126,11 +126,6 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
logger(DEBUG, "Channel closed")


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
logger(DEBUG, channel_connectivity)


def log(
run_id: Annotated[
int,
Expand Down Expand Up @@ -171,27 +166,18 @@ def log(
)
raise typer.Exit(code=1)

_log_with_exec_api(app, federation_config, run_id, stream)
_log_with_exec_api(app, federation, federation_config, run_id, stream)


def _log_with_exec_api(
app: Path,
federation: str,
federation_config: dict[str, Any],
run_id: int,
stream: bool,
) -> None:

insecure, root_certificates_bytes = validate_certificate_in_federation_config(
app, federation_config
)
channel = create_channel(
server_address=federation_config["address"],
insecure=insecure,
root_certificates=root_certificates_bytes,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
)
channel.subscribe(on_channel_state_change)
auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
channel = init_channel(app, federation_config, auth_plugin)

if stream:
start_stream(run_id, channel, CONN_REFRESH_PERIOD)
Expand Down
21 changes: 21 additions & 0 deletions src/py/flwr/cli/login/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `login` command."""

from .login import login as login

__all__ = [
"login",
]
90 changes: 90 additions & 0 deletions src/py/flwr/cli/login/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `login` command."""

from pathlib import Path
from typing import Annotated, Optional

import typer

from flwr.cli.config_utils import (
load_and_validate,
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common.constant import AUTH_TYPE
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
GetLoginDetailsRequest,
GetLoginDetailsResponse,
)
from flwr.proto.exec_pb2_grpc import ExecStub

from ..utils import init_channel, try_obtain_cli_auth_plugin


def login( # pylint: disable=R0914
app: Annotated[
Path,
typer.Argument(help="Path of the Flower App to run."),
] = Path("."),
federation: Annotated[
Optional[str],
typer.Argument(help="Name of the federation to login into."),
] = None,
) -> None:
"""Login to Flower SuperLink."""
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

pyproject_path = app / "pyproject.toml" if app else None
config, errors, warnings = load_and_validate(path=pyproject_path)

config = validate_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
)

if "address" not in federation_config:
typer.secho(
"❌ `flwr login` currently works with a SuperLink. Ensure that the correct"
"SuperLink (Exec API) address is provided in `pyproject.toml`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

channel = init_channel(app, federation_config, None)
stub = ExecStub(channel)

login_request = GetLoginDetailsRequest()
login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)

# Get the auth plugin
auth_type = login_response.login_details.get(AUTH_TYPE)
auth_plugin = try_obtain_cli_auth_plugin(
app, federation, federation_config, auth_type
)
if auth_plugin is None:
typer.secho(
f'❌ Authentication type "{auth_type}" not found',
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

# Login
auth_config = auth_plugin.login(dict(login_response.login_details), stub)

# Store the tokens
auth_plugin.store_tokens(auth_config)
35 changes: 6 additions & 29 deletions src/py/flwr/cli/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
import io
import json
from datetime import datetime, timedelta
from logging import DEBUG
from pathlib import Path
from typing import Annotated, Any, Optional, Union
from typing import Annotated, Optional, Union

import grpc
import typer
from rich.console import Console
from rich.table import Table
Expand All @@ -31,14 +29,12 @@

from flwr.cli.config_utils import (
load_and_validate,
validate_certificate_in_federation_config,
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common.constant import FAB_CONFIG_FILE, CliOutputFormat, SubStatus
from flwr.common.date import format_timedelta, isoformat8601_utc
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log, redirect_output, remove_emojis, restore_output
from flwr.common.logger import redirect_output, remove_emojis, restore_output
from flwr.common.serde import run_from_proto
from flwr.common.typing import Run
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
Expand All @@ -47,6 +43,8 @@
)
from flwr.proto.exec_pb2_grpc import ExecStub

from .utils import init_channel, try_obtain_cli_auth_plugin

_RunListType = tuple[int, str, str, str, str, str, str, str, str]


Expand Down Expand Up @@ -113,8 +111,8 @@ def ls( # pylint: disable=too-many-locals, too-many-branches
raise ValueError(
"The options '--runs' and '--run-id' are mutually exclusive."
)

channel = _init_channel(app, federation_config)
auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
channel = init_channel(app, federation_config, auth_plugin)
stub = ExecStub(channel)

# Display information about a specific run ID
Expand Down Expand Up @@ -154,27 +152,6 @@ def ls( # pylint: disable=too-many-locals, too-many-branches
captured_output.close()


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)


def _init_channel(app: Path, federation_config: dict[str, Any]) -> grpc.Channel:
"""Initialize gRPC channel to the Exec API."""
insecure, root_certificates_bytes = validate_certificate_in_federation_config(
app, federation_config
)
channel = create_channel(
server_address=federation_config["address"],
insecure=insecure,
root_certificates=root_certificates_bytes,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
)
channel.subscribe(on_channel_state_change)
return channel


def _format_runs(run_dict: dict[int, Run], now_isoformat: str) -> list[_RunListType]:
"""Format runs to a list."""

Expand Down
Loading

0 comments on commit 59809ec

Please sign in to comment.