Skip to content

Commit

Permalink
Rename disable_client_auth to require_client_auth with flipped de…
Browse files Browse the repository at this point in the history
…fault

Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed Nov 21, 2024
1 parent 980dab8 commit bede7a4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 32 deletions.
2 changes: 1 addition & 1 deletion openfl-workspace/workspace/plan/defaults/network.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ settings:
hash_salt : auto
use_tls : True
client_reconnect_interval : 5
disable_client_auth : False
require_client_auth : True
cert_folder : cert
31 changes: 14 additions & 17 deletions openfl/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,11 @@ class AggregatorGRPCClient:
Attributes:
uri (str): The URI of the aggregator.
use_tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
root_certificate (str): The path to the root certificate for the TLS
connection.
certificate (str): The path to the client's certificate for the TLS
connection.
private_key (str): The path to the client's private key for the TLS
connection.
require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS.
Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`.
certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`.
private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`.
aggregator_uuid (str): The UUID of the aggregator.
federation_uuid (str): The UUID of the federation.
single_col_cert_common_name (str): The common name on the
Expand All @@ -189,7 +186,7 @@ def __init__(
self,
agg_addr,
agg_port,
disable_client_auth,
require_client_auth,
root_certificate,
certificate,
private_key,
Expand All @@ -206,7 +203,7 @@ def __init__(
agg_addr (str): The address of the aggregator.
agg_port (int): The port of the aggregator.
use_tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
require_client_auth (bool): Whether to enable client-side
authentication.
root_certificate (str): The path to the root certificate for the
TLS connection.
Expand All @@ -222,7 +219,7 @@ def __init__(
"""
self.uri = f"{agg_addr}:{agg_port}"
self.use_tls = use_tls
self.disable_client_auth = disable_client_auth
self.require_client_auth = require_client_auth
self.root_certificate = root_certificate
self.certificate = certificate
self.private_key = private_key
Expand All @@ -236,7 +233,7 @@ def __init__(
self.channel = self.create_tls_channel(
self.uri,
self.root_certificate,
self.disable_client_auth,
self.require_client_auth,
self.certificate,
self.private_key,
)
Expand Down Expand Up @@ -278,7 +275,7 @@ def create_tls_channel(
self,
uri,
root_certificate,
disable_client_auth,
require_client_auth,
certificate,
private_key,
):
Expand All @@ -288,8 +285,8 @@ def create_tls_channel(
Args:
uri (str): The uniform resource identifier for the secure channel.
root_certificate (str): The Certificate Authority filename.
disable_client_auth (bool): True disables client-side
authentication (not recommended, throws warning to user).
require_client_auth (bool): True enables client-side
authentication.
certificate (str): The client certificate filename from the
collaborator (signed by the certificate authority).
private_key (str): The private key filename for the client
Expand All @@ -301,7 +298,7 @@ def create_tls_channel(
with open(root_certificate, "rb") as f:
root_certificate_b = f.read()

if disable_client_auth:
if not require_client_auth:
self.logger.warning("Client-side authentication is disabled.")
private_key_b = None
certificate_b = None
Expand Down Expand Up @@ -370,7 +367,7 @@ def reconnect(self):
self.channel = self.create_tls_channel(
self.uri,
self.root_certificate,
self.disable_client_auth,
self.require_client_auth,
self.certificate,
self.private_key,
)
Expand Down
25 changes: 11 additions & 14 deletions openfl/transport/grpc/aggregator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer):
aggregator (Aggregator): The aggregator that this server is serving.
uri (str): The URI that the server is serving on.
use_tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
root_certificate (str): The path to the root certificate for the TLS
connection.
certificate (str): The path to the server's certificate for the TLS
connection.
private_key (str): The path to the server's private key for the TLS
connection.
require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS.
Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`.
certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`.
private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`.
server (grpc.Server): The gRPC server.
server_credentials (grpc.ServerCredentials): The server's credentials.
"""
Expand All @@ -46,7 +43,7 @@ def __init__(
aggregator,
agg_port,
use_tls=True,
disable_client_auth=False,
require_client_auth=True,
root_certificate=None,
certificate=None,
private_key=None,
Expand All @@ -60,7 +57,7 @@ def __init__(
serving.
agg_port (int): The port that the server is serving on.
use_tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
require_client_auth (bool): Whether to enable client-side
authentication.
root_certificate (str): The path to the root certificate for the
TLS connection.
Expand All @@ -74,7 +71,7 @@ def __init__(
self.aggregator = aggregator
self.uri = f"[::]:{agg_port}"
self.use_tls = use_tls
self.disable_client_auth = disable_client_auth
self.require_client_auth = require_client_auth
self.root_certificate = root_certificate
self.certificate = certificate
self.private_key = private_key
Expand All @@ -100,7 +97,7 @@ def validate_collaborator(self, request, context):
"""
if self.use_tls:
collaborator_common_name = request.header.sender
if self.disable_client_auth:
if not self.require_client_auth:
common_name = collaborator_common_name
else:
common_name = context.auth_context()["x509_common_name"][0].decode("utf-8")
Expand Down Expand Up @@ -324,13 +321,13 @@ def get_server(self):
with open(self.root_certificate, "rb") as f:
root_certificate_b = f.read()

if self.disable_client_auth:
if not self.require_client_auth:
self.logger.warning("Client-side authentication is disabled.")

self.server_credentials = ssl_server_credentials(
((private_key_b, certificate_b),),
root_certificates=root_certificate_b,
require_client_auth=not self.disable_client_auth,
require_client_auth=self.require_client_auth,
)

self.server.add_secure_port(self.uri, self.server_credentials)
Expand Down

0 comments on commit bede7a4

Please sign in to comment.