diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 01b22c7b7c6d..27f2690cf97d 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -22,6 +22,7 @@ from uuid import UUID import grpc +from google.protobuf.message import Message as GrpcMessage from flwr.common.constant import Status from flwr.common.logger import log @@ -212,7 +213,7 @@ def PullServerAppInputs( # Lock access to LinkState, preventing obtaining the same pending run_id with self.lock: # If run_id is provided, use it, otherwise use the pending run_id - if request.HasField("run_id"): + if _has_field(request, "run_id"): run_id: Optional[int] = request.run_id else: run_id = state.get_pending_run_id() @@ -256,3 +257,8 @@ def PushServerAppOutputs( def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: raise ValueError(f"Malformed PushTaskInsRequest: {detail}") + + +def _has_field(message: GrpcMessage, field_name: str) -> bool: + """Check if a certain field is set for the message, including scalar fields.""" + return field_name in {fld.name for fld, _ in message.ListFields()}