diff --git a/openfl-workspace/workspace/plan/defaults/network.yaml b/openfl-workspace/workspace/plan/defaults/network.yaml index 5b4477b20b..11e03c1890 100644 --- a/openfl-workspace/workspace/plan/defaults/network.yaml +++ b/openfl-workspace/workspace/plan/defaults/network.yaml @@ -5,5 +5,5 @@ settings: hash_salt : auto use_tls : True client_reconnect_interval : 5 - disable_client_auth : False + require_client_auth : True cert_folder : cert diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index 8bb19fb559..74805218ec 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -171,14 +171,11 @@ class AggregatorGRPCClient: Attributes: uri (str): The URI of the aggregator. use_tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. - root_certificate (str): The path to the root certificate for the TLS - connection. - certificate (str): The path to the client's certificate for the TLS - connection. - private_key (str): The path to the client's private key for the TLS - connection. + require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. + Ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. + private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. aggregator_uuid (str): The UUID of the aggregator. federation_uuid (str): The UUID of the federation. single_col_cert_common_name (str): The common name on the @@ -189,7 +186,7 @@ def __init__( self, agg_addr, agg_port, - disable_client_auth, + require_client_auth, root_certificate, certificate, private_key, @@ -206,7 +203,7 @@ def __init__( agg_addr (str): The address of the aggregator. agg_port (int): The port of the aggregator. use_tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side + require_client_auth (bool): Whether to enable client-side authentication. root_certificate (str): The path to the root certificate for the TLS connection. @@ -222,7 +219,7 @@ def __init__( """ self.uri = f"{agg_addr}:{agg_port}" self.use_tls = use_tls - self.disable_client_auth = disable_client_auth + self.require_client_auth = require_client_auth self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key @@ -236,7 +233,7 @@ def __init__( self.channel = self.create_tls_channel( self.uri, self.root_certificate, - self.disable_client_auth, + self.require_client_auth, self.certificate, self.private_key, ) @@ -278,7 +275,7 @@ def create_tls_channel( self, uri, root_certificate, - disable_client_auth, + require_client_auth, certificate, private_key, ): @@ -288,8 +285,8 @@ def create_tls_channel( Args: uri (str): The uniform resource identifier for the secure channel. root_certificate (str): The Certificate Authority filename. - disable_client_auth (bool): True disables client-side - authentication (not recommended, throws warning to user). + require_client_auth (bool): True enables client-side + authentication. certificate (str): The client certificate filename from the collaborator (signed by the certificate authority). private_key (str): The private key filename for the client @@ -301,7 +298,7 @@ def create_tls_channel( with open(root_certificate, "rb") as f: root_certificate_b = f.read() - if disable_client_auth: + if not require_client_auth: self.logger.warning("Client-side authentication is disabled.") private_key_b = None certificate_b = None @@ -370,7 +367,7 @@ def reconnect(self): self.channel = self.create_tls_channel( self.uri, self.root_certificate, - self.disable_client_auth, + self.require_client_auth, self.certificate, self.private_key, ) diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 7d4b5c6b77..31b5ebc1c9 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -29,14 +29,11 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): aggregator (Aggregator): The aggregator that this server is serving. uri (str): The URI that the server is serving on. use_tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. - root_certificate (str): The path to the root certificate for the TLS - connection. - certificate (str): The path to the server's certificate for the TLS - connection. - private_key (str): The path to the server's private key for the TLS - connection. + require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. + Ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. + private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. server (grpc.Server): The gRPC server. server_credentials (grpc.ServerCredentials): The server's credentials. """ @@ -46,7 +43,7 @@ def __init__( aggregator, agg_port, use_tls=True, - disable_client_auth=False, + require_client_auth=True, root_certificate=None, certificate=None, private_key=None, @@ -60,7 +57,7 @@ def __init__( serving. agg_port (int): The port that the server is serving on. use_tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side + require_client_auth (bool): Whether to enable client-side authentication. root_certificate (str): The path to the root certificate for the TLS connection. @@ -74,7 +71,7 @@ def __init__( self.aggregator = aggregator self.uri = f"[::]:{agg_port}" self.use_tls = use_tls - self.disable_client_auth = disable_client_auth + self.require_client_auth = require_client_auth self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key @@ -100,7 +97,7 @@ def validate_collaborator(self, request, context): """ if self.use_tls: collaborator_common_name = request.header.sender - if self.disable_client_auth: + if not self.require_client_auth: common_name = collaborator_common_name else: common_name = context.auth_context()["x509_common_name"][0].decode("utf-8") @@ -324,13 +321,13 @@ def get_server(self): with open(self.root_certificate, "rb") as f: root_certificate_b = f.read() - if self.disable_client_auth: + if not self.require_client_auth: self.logger.warning("Client-side authentication is disabled.") self.server_credentials = ssl_server_credentials( ((private_key_b, certificate_b),), root_certificates=root_certificate_b, - require_client_auth=not self.disable_client_auth, + require_client_auth=self.require_client_auth, ) self.server.add_secure_port(self.uri, self.server_credentials)